Shengxiao0709 committed on
Commit 8f72b1f · verified · 1 Parent(s): 1978e37

Upload 78 files

This view is limited to 50 of the 78 changed files because the commit contains too many changes; see the raw diff for the complete change set.
Files changed (50)
  1. .gitattributes +2 -0
  2. _utils/attn_utils.py +592 -0
  3. _utils/attn_utils_new.py +610 -0
  4. _utils/config.yaml +15 -0
  5. _utils/example_config.yaml +20 -0
  6. _utils/load_models.py +16 -0
  7. _utils/load_track_data.py +104 -0
  8. _utils/misc_helper.py +37 -0
  9. _utils/seg_eval.py +61 -0
  10. _utils/track_args.py +157 -0
  11. config.py +44 -0
  12. counting.py +337 -0
  13. example_imgs/1977_Well_F-5_Field_1.png +3 -0
  14. example_imgs/1977_Well_F-5_Field_1_seg.png +3 -0
  15. models/.DS_Store +0 -0
  16. models/enc_model/__init__.py +0 -0
  17. models/enc_model/backbone.py +64 -0
  18. models/enc_model/loca.py +232 -0
  19. models/enc_model/loca_args.py +44 -0
  20. models/enc_model/mlp.py +23 -0
  21. models/enc_model/ope.py +245 -0
  22. models/enc_model/positional_encoding.py +30 -0
  23. models/enc_model/regression_head.py +92 -0
  24. models/enc_model/transformer.py +94 -0
  25. models/enc_model/unet_parts.py +77 -0
  26. models/model.py +991 -0
  27. models/seg_post_model/cellpose/__init__.py +1 -0
  28. models/seg_post_model/cellpose/__main__.py +272 -0
  29. models/seg_post_model/cellpose/cli.py +240 -0
  30. models/seg_post_model/cellpose/core.py +322 -0
  31. models/seg_post_model/cellpose/denoise.py +1474 -0
  32. models/seg_post_model/cellpose/dynamics.py +691 -0
  33. models/seg_post_model/cellpose/export.py +405 -0
  34. models/seg_post_model/cellpose/gui/gui.py +2007 -0
  35. models/seg_post_model/cellpose/gui/gui3d.py +667 -0
  36. models/seg_post_model/cellpose/gui/guihelpwindowtext.html +143 -0
  37. models/seg_post_model/cellpose/gui/guiparts.py +793 -0
  38. models/seg_post_model/cellpose/gui/guitrainhelpwindowtext.html +25 -0
  39. models/seg_post_model/cellpose/gui/io.py +634 -0
  40. models/seg_post_model/cellpose/gui/make_train.py +107 -0
  41. models/seg_post_model/cellpose/gui/menus.py +145 -0
  42. models/seg_post_model/cellpose/io.py +816 -0
  43. models/seg_post_model/cellpose/metrics.py +205 -0
  44. models/seg_post_model/cellpose/models.py +524 -0
  45. models/seg_post_model/cellpose/plot.py +281 -0
  46. models/seg_post_model/cellpose/transforms.py +1261 -0
  47. models/seg_post_model/cellpose/utils.py +667 -0
  48. models/seg_post_model/cellpose/version.py +18 -0
  49. models/seg_post_model/cellpose/vit_sam.py +195 -0
  50. models/seg_post_model/cellpose/vit_sam_new.py +197 -0
.gitattributes CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
  003_img.png filter=lfs diff=lfs merge=lfs -text
  1977_Well_F-5_Field_1.png filter=lfs diff=lfs merge=lfs -text
+ example_imgs/1977_Well_F-5_Field_1_seg.png filter=lfs diff=lfs merge=lfs -text
+ example_imgs/1977_Well_F-5_Field_1.png filter=lfs diff=lfs merge=lfs -text
_utils/attn_utils.py ADDED
@@ -0,0 +1,592 @@
import abc

import cv2
import numpy as np
import torch
from IPython.display import display
from PIL import Image
from typing import Union, Tuple, List
from einops import rearrange, repeat
import math
from torch import nn, einsum
from inspect import isfunction
from diffusers.utils import logging

try:
    from diffusers.models.unet_2d_condition import UNet2DConditionOutput
except ImportError:
    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

try:
    from diffusers.models.cross_attention import CrossAttention
except ImportError:
    from diffusers.models.attention_processor import Attention as CrossAttention

MAX_NUM_WORDS = 77
LOW_RESOURCE = False


class CountingCrossAttnProcessor1:

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, dim = hidden_states.shape
        h = attn_layer.heads
        q = attn_layer.to_q(hidden_states)
        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states
        k = attn_layer.to_k(context)
        v = attn_layer.to_v(context)
        # local reimplementations of head_to_batch_dim/batch_to_head_dim, which
        # moved (and were renamed) across diffusers versions
        q = self.head_to_batch_dim(q, h)
        k = self.head_to_batch_dim(k, h)
        v = self.head_to_batch_dim(v, h)

        sim = torch.einsum("b i d, b j d -> b i j", q, k) * attn_layer.scale

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
            sim.masked_fill_(~attention_mask, max_neg_value)

        # attention, what we cannot get enough of
        attn_ = sim.softmax(dim=-1).clone()
        self.attnstore(attn_, is_cross, self.place_in_unet)
        out = torch.einsum("b i j, b j d -> b i d", attn_, v)
        out = self.batch_to_head_dim(out, h)

        # newer diffusers wraps the output projection (and dropout) in a ModuleList
        if isinstance(attn_layer.to_out, nn.ModuleList):
            to_out = attn_layer.to_out[0]
        else:
            to_out = attn_layer.to_out

        out = to_out(out)
        return out

    def batch_to_head_dim(self, tensor, head_size):
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, head_size, out_dim=3):
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor


def register_attention_control(model, controller):
    attn_procs = {}
    cross_att_count = 0
    for name in model.unet.attn_processors.keys():
        # cross_attention_dim / hidden_size follow the diffusers processor example; unused here
        cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = model.unet.config.block_out_channels[-1]
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = model.unet.config.block_out_channels[block_id]
            place_in_unet = "down"
        else:
            continue

        cross_att_count += 1
        attn_procs[name] = CountingCrossAttnProcessor1(
            attnstore=controller, place_in_unet=place_in_unet
        )

    model.unet.set_attn_processor(attn_procs)
    controller.num_att_layers = cross_att_count


def register_hier_output(model):
    self = model.unet
    from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding
    logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

    def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None,
                attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None,
                down_block_additional_residuals=None, mid_block_additional_residual=None,
                encoder_attention_mask=None, return_dict=True):

        out_list = []

        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)
                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)  # 1, 320, 64, 64

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)

        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )

            # if i in [1, 4, 7]:
            out_list.append(sample)

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample), out_list

    self.forward = forward


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # with classifier-free guidance the first half of the batch is
                # unconditional; only the conditional half is forwarded/stored
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0


class EmptyControl(AttentionControl):

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if len(self.global_store) == 0:
                    self.global_store = self.step_store
                else:
                    for key in self.global_store:
                        for i in range(len(self.global_store[key])):
                            self.global_store[key][i] += self.step_store[key][i].detach()
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = self.attention_store
        return average_attention

    def get_average_global_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.global_store[key]]
                             for key in self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, max_size=32, save_global_store=False):
        '''
        Initialize an empty AttentionStore
        :param step_index: used to visualize only a specific step in the diffusion process
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.max_size = max_size
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        self.curr_step_index = 0


def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out


def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    # pass prompts through: aggregate_attention needs it to reshape per prompt
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))


def show_self_attention_comp(prompts, attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0):
    # prompts added to the signature: aggregate_attention requires it
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def view_images(images, num_rows=1, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)


def self_cross_attn(self_attn, cross_attn):
    res = self_attn.shape[0]
    assert res == cross_attn.shape[0]
    # cross attn [res, res] -> [res*res]
    cross_attn_ = cross_attn.reshape([res * res])
    # self_attn [res, res, res*res]
    self_cross_attn = cross_attn_ * self_attn
    self_cross_attn = self_cross_attn.mean(-1).unsqueeze(0).unsqueeze(0)
    return self_cross_attn
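A minimal usage sketch for the utilities above (illustrative, not part of this commit: the checkpoint, prompt, and step count are placeholders, and it assumes a diffusers version that invokes attention processors with the (attn, hidden_states, encoder_hidden_states, attention_mask) signature used by CountingCrossAttnProcessor1):

# Wire an AttentionStore into a Stable Diffusion pipeline, run one generation,
# then average the stored cross-attention maps at 16x16 resolution.
import torch
from diffusers import StableDiffusionPipeline

from _utils.attn_utils import AttentionStore, register_attention_control, aggregate_attention

pipe = StableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4")
pipe = pipe.to("cuda" if torch.cuda.is_available() else "cpu")

controller = AttentionStore(max_size=32)      # keep maps up to 32x32 spatial tokens
register_attention_control(pipe, controller)  # swap in CountingCrossAttnProcessor1 everywhere

prompts = ["a photo of three cells"]
_ = pipe(prompts[0], num_inference_steps=20)

attn_16 = aggregate_attention(prompts, controller, res=16,
                              from_where=["up", "down"], is_cross=True, select=0)
print(attn_16.shape)  # (16, 16, 77): one spatial map per text token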
_utils/attn_utils_new.py ADDED
@@ -0,0 +1,610 @@
import abc

import cv2
import numpy as np
import torch
from IPython.display import display
from PIL import Image
from typing import Any, Dict, List, Optional, Tuple, Union
from einops import rearrange, repeat
import math
from torch import nn, einsum
from inspect import isfunction
from diffusers.utils import logging

try:
    from diffusers.models.unet_2d_condition import UNet2DConditionOutput
except ImportError:
    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

try:
    from diffusers.models.cross_attention import CrossAttention
except ImportError:
    from diffusers.models.attention_processor import Attention as CrossAttention

MAX_NUM_WORDS = 77
LOW_RESOURCE = False


class CountingCrossAttnProcessor1:

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, dim = hidden_states.shape
        h = attn_layer.heads
        q = attn_layer.to_q(hidden_states)
        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states
        k = attn_layer.to_k(context)
        v = attn_layer.to_v(context)
        # local reimplementations of head_to_batch_dim/batch_to_head_dim, which
        # moved (and were renamed) across diffusers versions
        q = self.head_to_batch_dim(q, h)
        k = self.head_to_batch_dim(k, h)
        v = self.head_to_batch_dim(v, h)

        sim = torch.einsum("b i d, b j d -> b i j", q, k) * attn_layer.scale

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
            sim.masked_fill_(~attention_mask, max_neg_value)

        # attention, what we cannot get enough of
        attn_ = sim.softmax(dim=-1).clone()
        self.attnstore(attn_, is_cross, self.place_in_unet)
        out = torch.einsum("b i j, b j d -> b i d", attn_, v)
        out = self.batch_to_head_dim(out, h)

        # newer diffusers wraps the output projection (and dropout) in a ModuleList
        if isinstance(attn_layer.to_out, nn.ModuleList):
            to_out = attn_layer.to_out[0]
        else:
            to_out = attn_layer.to_out

        out = to_out(out)
        return out

    def batch_to_head_dim(self, tensor, head_size):
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, head_size, out_dim=3):
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor


def register_attention_control(model, controller):
    attn_procs = {}
    cross_att_count = 0
    for name in model.unet.attn_processors.keys():
        # cross_attention_dim / hidden_size follow the diffusers processor example; unused here
        cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = model.unet.config.block_out_channels[-1]
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = model.unet.config.block_out_channels[block_id]
            place_in_unet = "down"
        else:
            continue

        cross_att_count += 1
        attn_procs[name] = CountingCrossAttnProcessor1(
            attnstore=controller, place_in_unet=place_in_unet
        )

    model.unet.set_attn_processor(attn_procs)
    controller.num_att_layers = cross_att_count


def register_hier_output(model):
    self = model.unet
    logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

    def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None,
                attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None,
                down_block_additional_residuals=None, mid_block_additional_residual=None,
                encoder_attention_mask=None, return_dict=True):

        out_list = []

        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)
                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)  # 1, 320, 64, 64

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)

        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )

            out_list.append(sample)

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample), out_list

    self.forward = forward


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # with classifier-free guidance the first half of the batch is
                # unconditional; only the conditional half is forwarded/stored
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0


class EmptyControl(AttentionControl):

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if len(self.global_store) == 0:
                    self.global_store = self.step_store
                else:
                    for key in self.global_store:
                        for i in range(len(self.global_store[key])):
                            self.global_store[key][i] += self.step_store[key][i].detach()
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = self.attention_store
        return average_attention

    def get_average_global_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.global_store[key]]
                             for key in self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, max_size=32, save_global_store=False):
        '''
        Initialize an empty AttentionStore
        :param step_index: used to visualize only a specific step in the diffusion process
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.max_size = max_size
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        self.curr_step_index = 0


def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out


def aggregate_attention1(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    # unlike aggregate_attention, keep only the second collected map instead of
    # averaging across all of them
    out = out[1]
    out = out.sum(0) / out.shape[0]
    return out


def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    # pass prompts through: aggregate_attention needs it to reshape per prompt
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))


def show_self_attention_comp(prompts, attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0):
    # prompts added to the signature: aggregate_attention requires it
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def view_images(images, num_rows=1, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)


def self_cross_attn(self_attn, cross_attn):
    cross_attn = cross_attn.squeeze()
    res = self_attn.shape[0]
    assert res == cross_attn.shape[-1]
    # cross attn [res, res] -> [res*res]
    cross_attn_ = cross_attn.reshape([res * res])
    # self_attn [res, res, res*res]
    self_cross_attn = cross_attn_ * self_attn
    self_cross_attn = self_cross_attn.mean(-1).unsqueeze(0).unsqueeze(0)
    return self_cross_attn
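Unlike the variant in attn_utils.py, self_cross_attn here squeezes the cross-attention map and checks its last dimension, so it also accepts a [1, res, res] input. A shape-check sketch with random tensors (the shapes are assumptions inferred from the asserts above):

# Fuse a spatially reshaped self-attention tensor with one token's
# cross-attention map; the result is a [1, 1, res, res] relevance map.
import torch
from _utils.attn_utils_new import self_cross_attn

res = 16
self_attn = torch.rand(res, res, res * res)  # self-attention, reshaped to [res, res, res*res]
cross_attn = torch.rand(1, res, res)         # one token's cross-attention map

fused = self_cross_attn(self_attn, cross_attn)
print(fused.shape)  # torch.Size([1, 1, 16, 16])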
_utils/config.yaml ADDED
@@ -0,0 +1,15 @@
attn_dist_mode: v0
attn_positional_bias: rope
attn_positional_bias_n_spatial: 16
causal_norm: quiet_softmax
coord_dim: 2
d_model: 320
dropout: 0.0
feat_dim: 7
feat_embed_per_dim: 8
nhead: 4
num_decoder_layers: 6
num_encoder_layers: 6
pos_embed_per_dim: 32
spatial_pos_cutoff: 256
window: 4
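These read as hyperparameters for the tracking transformer (d_model 320, 4 heads, 6+6 encoder/decoder layers). A minimal sketch of loading them with PyYAML; how the repo actually wires this file into a model is not visible in this 50-file view:

# Load _utils/config.yaml into a simple namespace for attribute access.
import yaml
from types import SimpleNamespace

with open("_utils/config.yaml") as f:
    cfg = SimpleNamespace(**yaml.safe_load(f))

print(cfg.d_model, cfg.nhead, cfg.num_encoder_layers)  # 320 4 6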
_utils/example_config.yaml ADDED
@@ -0,0 +1,20 @@
batch_size: 1
crop_size:
- 256
- 256
detection_folders:
- TRA
dropout: 0.01
example_images: False # Slow
input_train:
- data/ctc/Fluo-N2DL-HeLa/01
input_val:
- data/ctc/Fluo-N2DL-HeLa/02
max_tokens: 2048
name: example
ndim: 2
num_decoder_layers: 5
num_encoder_layers: 5
outdir: runs
distributed: False
window: 4
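Several keys (window, dropout, num_encoder_layers) appear in both YAML files, which suggests the run config overrides the model defaults. A sketch of that overlay; the merge precedence is an assumption:

# Overlay example_config.yaml on top of the defaults in config.yaml.
import yaml

with open("_utils/config.yaml") as f:
    cfg = yaml.safe_load(f)
with open("_utils/example_config.yaml") as f:
    cfg.update(yaml.safe_load(f))  # run config wins on shared keys, e.g. num_encoder_layers: 5 replaces 6

print(cfg["window"], cfg["crop_size"])  # 4 [256, 256]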
_utils/load_models.py ADDED
@@ -0,0 +1,16 @@
from config import RunConfig
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
import torch.nn as nn


def load_stable_diffusion_model(config: RunConfig):
    device = torch.device('cpu')

    if config.sd_2_1:
        stable_diffusion_version = "stabilityai/stable-diffusion-2-1-base"
    else:
        stable_diffusion_version = "CompVis/stable-diffusion-v1-4"
    # stable = StableCountingPipeline.from_pretrained(stable_diffusion_version).to(device)
    stable = StableDiffusionPipeline.from_pretrained(stable_diffusion_version).to(device)
    return stable
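RunConfig comes from config.py, which is in this commit but not shown in this view, so its constructor arguments are an assumption. A sketch combining this loader with the attention hooks from _utils/attn_utils.py:

# Load the pipeline via the helper above and attach the attention store.
from config import RunConfig
from _utils.load_models import load_stable_diffusion_model
from _utils.attn_utils import AttentionStore, register_attention_control

config = RunConfig(sd_2_1=False)  # assumption: RunConfig exposes the sd_2_1 flag used above
stable = load_stable_diffusion_model(config)
controller = AttentionStore()
register_attention_control(stable, controller)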
_utils/load_track_data.py ADDED
@@ -0,0 +1,104 @@
import os
from glob import glob
from pathlib import Path
from natsort import natsorted
from PIL import Image
import numpy as np
import tifffile
import skimage.io as io
import torchvision.transforms as T
import cv2
from tqdm import tqdm
from models.tra_post_model.trackastra.utils import normalize_01, normalize

IMG_SIZE = 512


def _load_tiffs(folder: Path, dtype=None):
    """Load a sequence of tiff files from a folder into a 3D numpy array."""
    images = glob(str(folder / "*.tif"))
    test_data = tifffile.imread(images[0])
    # multi-channel stacks are converted to grayscale below
    turn_gray = len(test_data.shape) == 3
    end_frame = len(images)
    if not turn_gray:
        x = np.stack([
            tifffile.imread(f).astype(dtype)
            for f in tqdm(
                sorted(folder.glob("*.tif"))[0:end_frame:1],
                leave=False,
                desc=f"Loading [0:{end_frame}]",
            )
        ])
    else:
        x = []
        for f in tqdm(
            sorted(folder.glob("*.tif"))[0:end_frame:1],
            leave=False,
            desc=f"Loading [0:{end_frame}]",
        ):
            img = tifffile.imread(f).astype(dtype)
            if img.ndim == 3:
                if img.shape[-1] > 3:
                    img = img[..., :3]
                # ITU-R BT.601 luma weights for RGB -> grayscale
                img = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2])
            x.append(img)
        x = np.stack(x)
    return x


def load_track_images(file_dir):
    assert len(glob(file_dir + "/*.tif")) > 0, f"No tif images found in {file_dir}"
    images = natsorted(glob(file_dir + "/*.tif"))
    imgs = []
    imgs_raw = []
    images_stable = []
    # load images for seg and track
    for img_path in tqdm(images, desc="Loading images"):
        img = tifffile.imread(img_path)
        img_raw = io.imread(img_path)

        if img.dtype == 'uint16':
            # min-max normalize 16-bit frames to uint8 and replicate to 3 channels
            img = ((img - img.min()) / (img.max() - img.min() + 1e-6) * 255).astype(np.uint8)
            img = np.stack([img] * 3, axis=-1)
            w, h = img.shape[1], img.shape[0]
        else:
            img = Image.open(img_path).convert("RGB")
            w, h = img.size

        img = T.Compose([
            T.ToTensor(),
            T.Resize((IMG_SIZE, IMG_SIZE)),
        ])(img)

        image_stable = img - 0.5
        img = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

        imgs.append(img)
        imgs_raw.append(img_raw)
        images_stable.append(image_stable)

    height = h
    width = w
    imgs = np.stack(imgs, axis=0)
    imgs_raw = np.stack(imgs_raw, axis=0)
    images_stable = np.stack(images_stable, axis=0)

    # track data
    imgs_ = _load_tiffs(Path(file_dir), dtype=np.float32)
    imgs_01 = np.stack([
        normalize_01(_x) for _x in tqdm(imgs_, desc="Normalizing", leave=False)
    ])
    imgs_ = np.stack([
        normalize(_x) for _x in tqdm(imgs_, desc="Normalizing", leave=False)
    ])

    return imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width


if __name__ == "__main__":
    file_dir = "data/2D+Time/DIC-C2DH-HeLa/train/DIC-C2DH-HeLa/02"
    imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width = load_track_images(file_dir)
    print(imgs.shape, imgs_raw.shape, images_stable.shape, imgs_.shape, imgs_01.shape, height, width)
_utils/misc_helper.py ADDED
@@ -0,0 +1,37 @@
+ import logging
+ import os
+ import random
+ import shutil
+ from collections.abc import Mapping
+ from datetime import datetime
+
+ import numpy as np
+ import torch
+ import torch.distributed as dist
+
+
+ def basicConfig(*args, **kwargs):
+     return
+
+
+ # To prevent duplicate logs, we mask this basicConfig setting
+ logging.basicConfig = basicConfig
+
+
+ def create_logger(name, log_file, level=logging.INFO):
+     log = logging.getLogger(name)
+     formatter = logging.Formatter(
+         "[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s"
+     )
+     fh = logging.FileHandler(log_file)
+     fh.setFormatter(formatter)
+     sh = logging.StreamHandler()
+     sh.setFormatter(formatter)
+     log.setLevel(level)
+     log.addHandler(fh)
+     log.addHandler(sh)
+     return log
+
+
+ def get_current_time():
+     current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
+     return current_time
_utils/seg_eval.py ADDED
@@ -0,0 +1,61 @@
+ import torch
+ from torch import Tensor
+
+
+ def iou_torch(inst1, inst2):
+     inter = torch.logical_and(inst1, inst2).sum().float()
+     union = torch.logical_or(inst1, inst2).sum().float()
+     if union == 0:
+         return torch.tensor(float('nan'))
+     return inter / union
+
+
+ def get_instances_torch(mask):
+     # return boolean masks for all non-background instances
+     ids = torch.unique(mask)
+     return [(mask == i) for i in ids if i != 0]
+
+
+ def compute_instance_miou(pred_mask, gt_mask):
+     # pred_mask and gt_mask are integer torch.Tensors of shape [H, W]
+     pred_instances = get_instances_torch(pred_mask)
+     gt_instances = get_instances_torch(gt_mask)
+
+     ious = []
+     for gt in gt_instances:
+         best_iou = torch.tensor(0.0).to(pred_mask.device)
+         for pred in pred_instances:
+             i = iou_torch(pred, gt)
+             if i > best_iou:
+                 best_iou = i
+         ious.append(best_iou)
+
+     # handle the empty case
+     if len(ious) == 0:
+         return torch.tensor(float('nan'))
+     return torch.nanmean(torch.stack(ious))
+
+
+ def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
+     # Average of Dice coefficient for all batches, or for a single mask
+     assert input.size() == target.size()
+     assert input.dim() == 3 or not reduce_batch_first
+
+     sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)
+
+     inter = 2 * (input * target).sum(dim=sum_dim)
+     sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
+     sets_sum = torch.where(sets_sum == 0, inter, sets_sum)
+
+     dice = (inter + epsilon) / (sets_sum + epsilon)
+     return dice.mean()
+
+
+ def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
+     # Average of Dice coefficient for all classes
+     return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)
+
+
+ def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
+     # Dice loss (objective to minimize) between 0 and 1
+     fn = multiclass_dice_coeff if multiclass else dice_coeff
+     return 1 - fn(input, target, reduce_batch_first=True)
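A worked toy example (illustrative): one ground-truth instance of 4 pixels against a prediction of 3 overlapping pixels gives IoU 3/4, and Dice 2·3/(4+3) ≈ 0.857.

    import torch

    gt   = torch.tensor([[1, 1, 0, 0], [1, 1, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
    pred = torch.tensor([[1, 1, 0, 0], [1, 0, 0, 0], [0, 0, 0, 0], [0, 0, 0, 0]])
    print(compute_instance_miou(pred, gt))        # tensor(0.7500)
    print(dice_coeff(pred.float(), gt.float()))   # tensor(0.8571)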
_utils/track_args.py ADDED
@@ -0,0 +1,157 @@
+ import configargparse
+
+
+ def parse_train_args():
+     parser = configargparse.ArgumentParser(
+         formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
+         config_file_parser_class=configargparse.YAMLConfigFileParser,
+         allow_abbrev=False,
+     )
+     parser.add_argument(
+         "-c",
+         "--config",
+         default="_utils/example_config.yaml",
+         is_config_file=True,
+         help="config file path",
+     )
+     parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
+     parser.add_argument("-o", "--outdir", type=str, default="runs")
+     parser.add_argument("--name", type=str, help="Name to append to timestamp")
+     parser.add_argument("--timestamp", type=bool, default=True)
+     parser.add_argument(
+         "-m",
+         "--model",
+         type=str,
+         default="",
+         help="load this model at start (e.g. to continue training)",
+     )
+     parser.add_argument(
+         "--ndim", type=int, default=2, help="number of spatial dimensions"
+     )
+     parser.add_argument("-d", "--d_model", type=int, default=256)
+     parser.add_argument("-w", "--window", type=int, default=10)
+     parser.add_argument("--epochs", type=int, default=100)
+     parser.add_argument("--warmup_epochs", type=int, default=10)
+     parser.add_argument(
+         "--detection_folders",
+         type=str,
+         nargs="+",
+         default=["TRA"],
+         help=(
+             "Subfolders to search for detections. Defaults to `TRA`, which corresponds"
+             " to using only the GT."
+         ),
+     )
+     parser.add_argument("--downscale_temporal", type=int, default=1)
+     parser.add_argument("--downscale_spatial", type=int, default=1)
+     parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
+     parser.add_argument("--from_subfolder", action="store_true")
+     # parser.add_argument("--train_samples", type=int, default=50000)
+     parser.add_argument("--num_encoder_layers", type=int, default=6)
+     parser.add_argument("--num_decoder_layers", type=int, default=6)
+     parser.add_argument("--pos_embed_per_dim", type=int, default=32)
+     parser.add_argument("--feat_embed_per_dim", type=int, default=8)
+     parser.add_argument("--dropout", type=float, default=0.00)
+     parser.add_argument("--num_workers", type=int, default=4)
+     parser.add_argument("--batch_size", type=int, default=1)
+     parser.add_argument("--max_tokens", type=int, default=None)
+     parser.add_argument("--delta_cutoff", type=int, default=2)
+     parser.add_argument("--lr", type=float, default=1e-4)
+     parser.add_argument(
+         "--attn_positional_bias",
+         type=str,
+         choices=["rope", "bias", "none"],
+         default="rope",
+     )
+     parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
+     parser.add_argument("--attn_dist_mode", default="v0")
+     parser.add_argument("--mixedp", type=bool, default=True)
+     parser.add_argument("--dry", action="store_true")
+     parser.add_argument("--profile", action="store_true")
+     parser.add_argument(
+         "--features",
+         type=str,
+         choices=[
+             "none",
+             "regionprops",
+             "regionprops2",
+             "patch",
+             "patch_regionprops",
+             "wrfeat",
+         ],
+         default="wrfeat",
+     )
+     parser.add_argument(
+         "--causal_norm",
+         type=str,
+         choices=["none", "linear", "softmax", "quiet_softmax"],
+         default="quiet_softmax",
+     )
+     parser.add_argument("--div_upweight", type=float, default=2)
+
+     parser.add_argument("--augment", type=int, default=3)
+     parser.add_argument("--tracking_frequency", type=int, default=-1)
+
+     parser.add_argument("--sanity_dist", action="store_true")
+     parser.add_argument("--preallocate", type=bool, default=False)
+     parser.add_argument("--only_prechecks", action="store_true")
+     parser.add_argument(
+         "--compress", type=bool, default=True, help="compress dataset"
+     )
+
+     parser.add_argument("--seed", type=int, default=None)
+     parser.add_argument(
+         "--logger",
+         type=str,
+         default="tensorboard",
+         choices=["tensorboard", "wandb", "none"],
+     )
+     parser.add_argument("--wandb_project", type=str, default="trackastra")
+     parser.add_argument(
+         "--crop_size",
+         type=int,
+         # required=True,
+         nargs="+",
+         default=None,
+         help="random crop size for augmentation",
+     )
+     parser.add_argument(
+         "--weight_by_ndivs",
+         type=bool,
+         default=True,
+         help="Oversample windows that contain divisions",
+     )
+     parser.add_argument(
+         "--weight_by_dataset",
+         type=bool,
+         default=False,
+         help=(
+             "Inversely weight datasets by number of samples (to counter dataset size"
+             " imbalance)"
+         ),
+     )
+
+     args, unknown_args = parser.parse_known_args()
+
+     # # Hack to allow for --input_test
+     # allowed_unknown = ["input_test"]
+     # if not set(a.split("=")[0].strip("-") for a in unknown_args).issubset(
+     #     set(allowed_unknown)
+     # ):
+     #     raise ValueError(f"Unknown args: {unknown_args}")
+
+     # pprint(vars(args))
+
+     # for backward compatibility
+     # if args.attn_positional_bias == "True":
+     #     args.attn_positional_bias = "bias"
+     # elif args.attn_positional_bias == "False":
+     #     args.attn_positional_bias = False
+
+     # if args.train_samples == 0:
+     #     raise NotImplementedError(
+     #         "--train_samples must be > 0, full dataset pass not supported."
+     #     )
+
+     return args
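Usage sketch (illustrative; assumes it is run from the repo root): configargparse resolves values with CLI flags taking precedence over the YAML config, which in turn overrides the argparse defaults.

    # e.g. python train.py -c _utils/example_config.yaml --lr 3e-4
    from _utils.track_args import parse_train_args

    args = parse_train_args()
    print(args.window, args.lr)   # 4 (from the YAML above, overriding the default 10), 0.0003 (from the CLI)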
config.py ADDED
@@ -0,0 +1,44 @@
+ from dataclasses import dataclass, field
+ from pathlib import Path
+ from typing import Dict, List
+
+
+ @dataclass
+ class RunConfig:
+     # Guiding text prompt
+     prompt: str = "<task-prompt>"
+     # Whether to use Stable Diffusion v2.1
+     sd_2_1: bool = False
+     # Which token indices to alter with attend-and-excite
+     token_indices: List[int] = field(default_factory=lambda: [2, 5])
+     # Which random seeds to use when generating
+     seeds: List[int] = field(default_factory=lambda: [42])
+     # Path to save all outputs to
+     output_path: Path = Path('./outputs')
+     # Number of denoising steps
+     n_inference_steps: int = 50
+     # Text guidance scale
+     guidance_scale: float = 7.5
+     # Number of denoising steps to apply attend-and-excite
+     max_iter_to_alter: int = 25
+     # Resolution of UNet to compute attention maps over
+     attention_res: int = 16
+     # Whether to run standard SD or attend-and-excite
+     run_standard_sd: bool = False
+     # Dictionary defining the iterations and desired thresholds for iterative latent refinement
+     thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8})
+     # Scale factor for updating the denoised latent z_t
+     scale_factor: int = 20
+     # Start and end values used for scaling the scale factor - decays linearly with the denoising timestep
+     scale_range: tuple = field(default_factory=lambda: (1.0, 0.5))
+     # Whether to apply Gaussian smoothing before computing the maximum attention value for each subject token
+     smooth_attentions: bool = True
+     # Standard deviation for the Gaussian smoothing
+     sigma: float = 0.5
+     # Kernel size for the Gaussian smoothing
+     kernel_size: int = 3
+     # Whether to save cross attention maps for the final results
+     save_cross_attention_maps: bool = False
+
+     def __post_init__(self):
+         self.output_path.mkdir(exist_ok=True, parents=True)
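Usage sketch (illustrative; counting.py below instantiates RunConfig() with the defaults): fields can be overridden at construction time, and __post_init__ creates the output directory as a side effect.

    from config import RunConfig

    cfg = RunConfig(sd_2_1=True, seeds=[0, 1, 2])
    print(cfg.prompt, cfg.n_inference_steps)   # <task-prompt> 50
    assert cfg.output_path.exists()            # created by __post_init__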
counting.py ADDED
@@ -0,0 +1,337 @@
+ # stable diffusion x loca
+ import os
+ # os.system("source /etc/network_turbo")
+ os.environ["CUDA_VISIBLE_DEVICES"] = "2"
+ import pprint
+ from typing import Any, List, Optional
+ import argparse
+ import pyrallis
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
+ import torch
+ from PIL import Image
+ import numpy as np
+ from config import RunConfig
+ from _utils import attn_utils_new as attn_utils
+ from _utils.attn_utils import AttentionStore
+ from _utils.misc_helper import *
+ import torch.nn.functional as F
+ import matplotlib.pyplot as plt
+ import cv2
+ import warnings
+ from pytorch_lightning.callbacks import ModelCheckpoint
+ warnings.filterwarnings("ignore", category=UserWarning)
+ import pytorch_lightning as pl
+ from _utils.load_models import load_stable_diffusion_model
+ from models.model import Counting_with_SD_features_loca as Counting
+ from pytorch_lightning.loggers import WandbLogger
+ from models.enc_model.loca_args import get_argparser as loca_get_argparser
+ from models.enc_model.loca import build_model as build_loca_model
+ import time
+ import torchvision.transforms as T
+ import skimage.io as io
+ from _utils.dummy_box_gen import gen_dummy_boxes
+
+ SCALE = 1
+
+
+ class CountingModule(pl.LightningModule):
+     def __init__(self, use_box=True):
+         super().__init__()
+         self.use_box = use_box
+         self.config = RunConfig()  # config for stable diffusion
+         self.initialize_model()
+
+     def initialize_model(self):
+         # load loca model
+         loca_args = loca_get_argparser().parse_args()
+         self.loca_model = build_loca_model(loca_args)
+         # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
+         # weights = {k.replace("module", ""): v for k, v in weights.items()}
+         # self.loca_model.load_state_dict(weights, strict=False)
+         # del weights
+
+         self.counting_adapter = Counting(scale_factor=SCALE)
+         # if os.path.isfile(self.args.adapter_weight):
+         #     adapter_weight = torch.load(self.args.adapter_weight, map_location=torch.device('cpu'))
+         #     self.counting_adapter.load_state_dict(adapter_weight, strict=False)
+
+         ### load stable diffusion and its controller
+         self.stable = load_stable_diffusion_model(config=self.config)
+         self.noise_scheduler = self.stable.scheduler
+         self.controller = AttentionStore(max_size=64)
+         attn_utils.register_attention_control(self.stable, self.controller)
+         attn_utils.register_hier_output(self.stable)
+
+         ##### initialize token_emb #####
+         placeholder_token = "<task-prompt>"
+         self.task_token = "repetitive objects"
+         # Add the placeholder token in tokenizer
+         num_added_tokens = self.stable.tokenizer.add_tokens(placeholder_token)
+         if num_added_tokens == 0:
+             raise ValueError(
+                 f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
+                 " `placeholder_token` that is not already in the tokenizer."
+             )
+         if os.path.isfile("pretrained/task_embed.pth"):
+             task_embed_from_pretrain = torch.load("pretrained/task_embed.pth")
+             placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)
+             self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))
+
+             token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
+             token_embeds[placeholder_token_id] = task_embed_from_pretrain
+         else:
+             initializer_token = "count"
+             token_ids = self.stable.tokenizer.encode(initializer_token, add_special_tokens=False)
+             # Check if initializer_token is a single token or a sequence of tokens
+             if len(token_ids) > 1:
+                 raise ValueError("The initializer token must be a single token.")
+
+             initializer_token_id = token_ids[0]
+             placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)
+
+             self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))
+
+             token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
+             token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]
+
+         # others
+         self.placeholder_token = placeholder_token
+         self.placeholder_token_id = placeholder_token_id
+
+     def move_to_device(self, device):
+         self.stable.to(device)
+         if self.loca_model is not None and self.counting_adapter is not None:
+             self.loca_model.to(device)
+             self.counting_adapter.to(device)
+         self.to(device)
+
+     def forward(self, data_path, box=None):
+         filename = data_path.split("/")[-1]
+         img = Image.open(data_path).convert("RGB")
+         width, height = img.size
+         input_image = T.Compose([T.ToTensor(), T.Resize((512, 512))])(img)
+         input_image_stable = input_image - 0.5
+         input_image = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_image)
+         if box is not None:
+             boxes = torch.tensor(box) / torch.tensor([width, height, width, height]) * 512  # xyxy, normalized
+             assert self.use_box == True
+         else:
+             boxes = torch.tensor([[100, 100, 130, 130], [200, 200, 250, 250]], dtype=torch.float32)  # dummy box
+             assert self.use_box == False
+
+         # move to device
+         input_image = input_image.unsqueeze(0).to(self.device)
+         boxes = boxes.unsqueeze(0).to(self.device)
+         input_image_stable = input_image_stable.unsqueeze(0).to(self.device)
+
+         latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
+         latents = latents * 0.18215
+         # Sample noise that we'll add to the latents
+         noise = torch.randn_like(latents)
+         timesteps = torch.tensor([20], device=latents.device).long()
+         noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
+         input_ids_ = self.stable.tokenizer(
+             self.placeholder_token + " repetitive objects",
+             # "object",
+             padding="max_length",
+             truncation=True,
+             max_length=self.stable.tokenizer.model_max_length,
+             return_tensors="pt",
+         )
+         input_ids = input_ids_["input_ids"].to(self.device)
+         attention_mask = input_ids_["attention_mask"].to(self.device)
+         encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]
+
+         input_image = input_image.to(self.device)
+         boxes = boxes.to(self.device)
+
+         task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
+         if self.use_box:
+             loca_out = self.loca_model.forward_before_reg(input_image, boxes)
+             loca_feature_bf_regression = loca_out["feature_bf_regression"]
+             adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes)  # shape [1, 768]
+             if task_loc_idx.shape[0] == 0:
+                 encoder_hidden_states[0, 2, :] = adapted_emb.squeeze()  # placed right after the task prompt token
+             else:
+                 encoder_hidden_states[0, task_loc_idx[0, 1] + 1, :] = adapted_emb.squeeze()  # placed right after the task prompt token
+
+         # Predict the noise residual
+         noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
+         noise_pred = noise_pred.sample
+         attention_store = self.controller.attention_store
+
+         attention_maps = []
+         exemplar_attention_maps = []
+         exemplar_attention_maps1 = []
+         exemplar_attention_maps2 = []
+         exemplar_attention_maps3 = []
+
+         cross_self_task_attn_maps = []
+         cross_self_exe_attn_maps = []
+
+         # only use 64x64 self-attention
+         self_attn_aggregate = attn_utils.aggregate_attention(  # [res, res, res*res]
+             prompts=[self.config.prompt],  # TODO: does the prompt need to change here?
+             attention_store=self.controller,
+             res=64,
+             from_where=("up", "down"),
+             is_cross=False,
+             select=0
+         )
+         self_attn_aggregate32 = attn_utils.aggregate_attention(  # [res, res, res*res]
+             prompts=[self.config.prompt],
+             attention_store=self.controller,
+             res=32,
+             from_where=("up", "down"),
+             is_cross=False,
+             select=0
+         )
+         self_attn_aggregate16 = attn_utils.aggregate_attention(  # [res, res, res*res]
+             prompts=[self.config.prompt],
+             attention_store=self.controller,
+             res=16,
+             from_where=("up", "down"),
+             is_cross=False,
+             select=0
+         )
+
+         # cross attention
+         for res in [32, 16]:
+             attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 77]
+                 prompts=[self.config.prompt],
+                 attention_store=self.controller,
+                 res=res,
+                 from_where=("up", "down"),
+                 is_cross=True,
+                 select=0
+             )
+
+             task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0)  # [1, 1, res, res]
+             attention_maps.append(task_attn_)
+             if self.use_box:
+                 exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)  # take the exemplar token's attention
+                 exemplar_attention_maps.append(exemplar_attns)
+             else:
+                 exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)
+                 exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0)
+                 exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0)
+                 exemplar_attention_maps1.append(exemplar_attns1)
+                 exemplar_attention_maps2.append(exemplar_attns2)
+                 exemplar_attention_maps3.append(exemplar_attns3)
+
+         scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
+         attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))])
+         task_attn_64 = torch.mean(attns, dim=0, keepdim=True)
+         cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
+         cross_self_task_attn_maps.append(cross_self_task_attn)
+
+         if self.use_box:
+             scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))]
+             attns = torch.cat([F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))])
+             exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True)
+
+             cross_self_exe_attn = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64)
+             cross_self_exe_attn_maps.append(cross_self_exe_attn)
+         else:
+             scale_factors = [(64 // exemplar_attention_maps1[i].shape[-1]) for i in range(len(exemplar_attention_maps1))]
+             attns = torch.cat([F.interpolate(exemplar_attention_maps1[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps1))])
+             exemplar_attn_64_1 = torch.mean(attns, dim=0, keepdim=True)
+
+             scale_factors = [(64 // exemplar_attention_maps2[i].shape[-1]) for i in range(len(exemplar_attention_maps2))]
+             attns = torch.cat([F.interpolate(exemplar_attention_maps2[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps2))])
+             exemplar_attn_64_2 = torch.mean(attns, dim=0, keepdim=True)
+
+             scale_factors = [(64 // exemplar_attention_maps3[i].shape[-1]) for i in range(len(exemplar_attention_maps3))]
+             attns = torch.cat([F.interpolate(exemplar_attention_maps3[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps3))])
+             exemplar_attn_64_3 = torch.mean(attns, dim=0, keepdim=True)
+
+             cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
+             cross_self_task_attn_maps.append(cross_self_task_attn)
+
+             # if self.args.merge_exemplar == "average":
+             cross_self_exe_attn1 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_1)
+             cross_self_exe_attn2 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_2)
+             cross_self_exe_attn3 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_3)
+             exemplar_attn_64 = (exemplar_attn_64_1 + exemplar_attn_64_2 + exemplar_attn_64_3) / 3
+             cross_self_exe_attn = (cross_self_exe_attn1 + cross_self_exe_attn2 + cross_self_exe_attn3) / 3
+
+         exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6)
+
+         attn_stack = [exemplar_attn_64 / 2, cross_self_exe_attn / 2, exemplar_attn_64, cross_self_exe_attn]
+         attn_stack = torch.cat(attn_stack, dim=1)
+
+         if not self.use_box:
+             # cross_self_exe_attn_np = cross_self_exe_attn.detach().squeeze().cpu().numpy()
+             # boxes = gen_dummy_boxes(cross_self_exe_attn_np, max_boxes=1)
+             # boxes = boxes.to(self.device)
+             # zero-shot path: extract LOCA features with the dummy boxes
+             loca_out = self.loca_model.forward_before_reg(input_image, boxes)
+             loca_feature_bf_regression = loca_out["feature_bf_regression"]
+
+         attn_out = self.loca_model.forward_reg(loca_out, attn_stack, feature_list[-1])
+         pred_density = attn_out["pred"].squeeze().cpu().numpy()
+         pred_cnt = pred_density.sum().item()
+
+         # resize pred_density to original image size
+         pred_density_rsz = cv2.resize(pred_density, (width, height), interpolation=cv2.INTER_CUBIC)
+         pred_density_rsz = pred_density_rsz / pred_density_rsz.sum() * pred_cnt
+
+         return pred_density_rsz, pred_cnt
+
+
+ def inference(data_path, box=None, save_path="./example_imgs", visualize=False):
+     if box is not None:
+         use_box = True
+     else:
+         use_box = False
+     model = CountingModule(use_box=use_box)
+     load_msg = model.load_state_dict(torch.load("pretrained/microscopy_matching_cnt.pth"), strict=True)
+     model.eval()
+     with torch.no_grad():
+         density_map, cnt = model(data_path, box)
+
+     if visualize:
+         img = io.imread(data_path)
+         if len(img.shape) == 3 and img.shape[2] > 3:
+             img = img[:, :, :3]
+         if len(img.shape) == 2:
+             img = np.stack([img] * 3, axis=-1)
+         img_show = img.squeeze()
+         density_map_show = density_map.squeeze()
+         os.makedirs(save_path, exist_ok=True)
+         filename = data_path.split("/")[-1]
+         img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show))
+         fig, ax = plt.subplots(1, 2, figsize=(12, 6))
+         ax[0].imshow(img_show)
+         ax[0].axis('off')
+         ax[0].set_title("Input image")
+         ax[1].imshow(img_show)
+         ax[1].imshow(density_map_show, cmap='jet', alpha=0.5)  # Overlay density map with some transparency
+         ax[1].axis('off')
+         ax[1].set_title(f"Predicted density map, count: {cnt:.1f}")
+         plt.tight_layout()
+         plt.savefig(os.path.join(save_path, filename.split(".")[0] + "_cnt.png"), dpi=300)
+         plt.close()
+     return density_map
+
+
+ def main():
+
+     inference(
+         data_path="example_imgs/1977_Well_F-5_Field_1.png",
+         # box=[[150, 60, 183, 87]],
+         save_path="./example_imgs",
+         visualize=True
+     )
+
+
+ if __name__ == "__main__":
+     main()
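The commented-out box argument in main() shows the few-shot entry point. A usage sketch (illustrative):

    # Few-shot variant: pass one or more exemplar boxes in xyxy pixel
    # coordinates of the original image -> CountingModule(use_box=True).
    density = inference(
        data_path="example_imgs/1977_Well_F-5_Field_1.png",
        box=[[150, 60, 183, 87]],
        save_path="./example_imgs",
        visualize=True,
    )
    print(density.sum())   # predicted count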
example_imgs/1977_Well_F-5_Field_1.png ADDED
Git LFS Details
  • SHA256: 145a99e724048ed40db7843e57a1d93cd2e1f6e221d167a29b732740d6302c52
  • Pointer size: 132 Bytes
  • Size of remote file: 2.43 MB
example_imgs/1977_Well_F-5_Field_1_seg.png ADDED
Git LFS Details
  • SHA256: ba965afa28d3b51f683d3e98fe28cf644c66c9625a215406639b5b6d3087dab9
  • Pointer size: 132 Bytes
  • Size of remote file: 3.48 MB
models/.DS_Store ADDED
Binary file (6.15 kB)
models/enc_model/__init__.py ADDED
File without changes
models/enc_model/backbone.py ADDED
@@ -0,0 +1,64 @@
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+ from torchvision import models
+ from torchvision.ops.misc import FrozenBatchNorm2d
+
+
+ class Backbone(nn.Module):
+
+     def __init__(
+         self,
+         name: str,
+         pretrained: bool,
+         dilation: bool,
+         reduction: int,
+         swav: bool,
+         requires_grad: bool
+     ):
+
+         super(Backbone, self).__init__()
+
+         resnet = getattr(models, name)(
+             replace_stride_with_dilation=[False, False, dilation],
+             pretrained=pretrained, norm_layer=FrozenBatchNorm2d
+         )
+
+         self.backbone = resnet
+         self.reduction = reduction
+
+         if name == 'resnet50' and swav:
+             checkpoint = torch.hub.load_state_dict_from_url(
+                 'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar',
+                 map_location="cpu"
+             )
+             state_dict = {k.replace("module.", ""): v for k, v in checkpoint.items()}
+             self.backbone.load_state_dict(state_dict, strict=False)
+
+         # concatenation of layers 2, 3 and 4
+         self.num_channels = 896 if name in ['resnet18', 'resnet34'] else 3584
+
+         for n, param in self.backbone.named_parameters():
+             if 'layer2' not in n and 'layer3' not in n and 'layer4' not in n:
+                 param.requires_grad_(False)
+             else:
+                 param.requires_grad_(requires_grad)
+
+     def forward(self, x):
+         size = x.size(-2) // self.reduction, x.size(-1) // self.reduction
+         x = self.backbone.conv1(x)
+         x = self.backbone.bn1(x)
+         x = self.backbone.relu(x)
+         x = self.backbone.maxpool(x)
+
+         x = self.backbone.layer1(x)
+         x = layer2 = self.backbone.layer2(x)
+         x = layer3 = self.backbone.layer3(x)
+         x = layer4 = self.backbone.layer4(x)
+
+         x = torch.cat([
+             F.interpolate(f, size=size, mode='bilinear', align_corners=True)
+             for f in [layer2, layer3, layer4]
+         ], dim=1)
+
+         return x
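A shape sketch (illustrative; pretrained=False and swav=False avoid any downloads): with resnet50 and reduction=8, the forward pass concatenates layers 2-4 (512 + 1024 + 2048 channels) at stride 8.

    import torch

    backbone = Backbone("resnet50", pretrained=False, dilation=False,
                        reduction=8, swav=False, requires_grad=False)
    x = torch.rand(1, 3, 512, 512)
    f = backbone(x)
    print(f.shape)   # torch.Size([1, 3584, 64, 64])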
models/enc_model/loca.py ADDED
@@ -0,0 +1,232 @@
+ from .backbone import Backbone
+ from .transformer import TransformerEncoder
+ from .ope import OPEModule
+ from .positional_encoding import PositionalEncodingsFixed
+ from .regression_head import DensityMapRegressor
+
+ import torch
+ from torch import nn
+ from torch.nn import functional as F
+
+
+ class LOCA(nn.Module):
+
+     def __init__(
+         self,
+         image_size: int,
+         num_encoder_layers: int,
+         num_ope_iterative_steps: int,
+         num_objects: int,
+         emb_dim: int,
+         num_heads: int,
+         kernel_dim: int,
+         backbone_name: str,
+         swav_backbone: bool,
+         train_backbone: bool,
+         reduction: int,
+         dropout: float,
+         layer_norm_eps: float,
+         mlp_factor: int,
+         norm_first: bool,
+         activation: nn.Module,
+         norm: bool,
+         zero_shot: bool,
+     ):
+
+         super(LOCA, self).__init__()
+
+         self.emb_dim = emb_dim
+         self.num_objects = num_objects
+         self.reduction = reduction
+         self.kernel_dim = kernel_dim
+         self.image_size = image_size
+         self.zero_shot = zero_shot
+         self.num_heads = num_heads
+         self.num_encoder_layers = num_encoder_layers
+
+         self.backbone = Backbone(
+             backbone_name, pretrained=True, dilation=False, reduction=reduction,
+             swav=swav_backbone, requires_grad=train_backbone
+         )
+         self.input_proj = nn.Conv2d(
+             self.backbone.num_channels, emb_dim, kernel_size=1
+         )
+
+         if num_encoder_layers > 0:
+             self.encoder = TransformerEncoder(
+                 num_encoder_layers, emb_dim, num_heads, dropout, layer_norm_eps,
+                 mlp_factor, norm_first, activation, norm
+             )
+
+         self.ope = OPEModule(
+             num_ope_iterative_steps, emb_dim, kernel_dim, num_objects, num_heads,
+             reduction, layer_norm_eps, mlp_factor, norm_first, activation, norm, zero_shot
+         )
+
+         self.regression_head = DensityMapRegressor(emb_dim, reduction)
+         self.aux_heads = nn.ModuleList([
+             DensityMapRegressor(emb_dim, reduction)
+             for _ in range(num_ope_iterative_steps - 1)
+         ])
+
+         self.pos_emb = PositionalEncodingsFixed(emb_dim)
+
+         self.attn_norm = nn.LayerNorm(normalized_shape=(64, 64))
+         self.fuse = nn.Sequential(
+             nn.Conv2d(324, 256, kernel_size=1, stride=1),
+             nn.LeakyReLU(),
+             nn.LayerNorm((64, 64))
+         )
+
+         # self.fuse1 = nn.Sequential(
+         #     nn.Conv2d(322, 256, kernel_size=1, stride=1),
+         #     nn.LeakyReLU(),
+         #     nn.LayerNorm((64, 64))
+         # )
+
+     def forward_before_reg(self, x, bboxes):
+         num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects
+         # backbone
+         backbone_features = self.backbone(x)
+         # prepare the encoder input
+         src = self.input_proj(backbone_features)
+         bs, c, h, w = src.size()
+         pos_emb = self.pos_emb(bs, h, w, src.device).flatten(2).permute(2, 0, 1)
+         src = src.flatten(2).permute(2, 0, 1)
+
+         # push through the encoder
+         if self.num_encoder_layers > 0:
+             image_features = self.encoder(src, pos_emb, src_key_padding_mask=None, src_mask=None)
+         else:
+             image_features = src
+
+         # prepare OPE input
+         f_e = image_features.permute(1, 2, 0).reshape(-1, self.emb_dim, h, w)
+
+         all_prototypes = self.ope(f_e, pos_emb, bboxes)  # [3, 27, 1, 256]
+
+         outputs = list()
+         response_maps_list = []
+         for i in range(all_prototypes.size(0)):
+             prototypes = all_prototypes[i, ...].permute(1, 0, 2).reshape(
+                 bs, num_objects, self.kernel_dim, self.kernel_dim, -1
+             ).permute(0, 1, 4, 2, 3).flatten(0, 2)[:, None, ...]  # [768, 1, 3, 3]
+
+             response_maps = F.conv2d(
+                 torch.cat([f_e for _ in range(num_objects)], dim=1).flatten(0, 1).unsqueeze(0),
+                 prototypes,
+                 bias=None,
+                 padding=self.kernel_dim // 2,
+                 groups=prototypes.size(0)
+             ).view(
+                 bs, num_objects, self.emb_dim, h, w
+             ).max(dim=1)[0]
+
+             # # send through regression heads
+             # if i == all_prototypes.size(0) - 1:
+             #     predicted_dmaps = self.regression_head(response_maps)
+             # else:
+             #     predicted_dmaps = self.aux_heads[i](response_maps)
+             # outputs.append(predicted_dmaps)
+             response_maps_list.append(response_maps)
+
+         out = {
+             # "pred": outputs[-1],
+             "feature_bf_regression": response_maps_list[-1],
+             # "aux_pred": outputs[:-1],
+             "aux_feature_bf_regression": response_maps_list[:-1]
+         }
+
+         return out
+
+     def forward_reg(self, response_maps, attn_stack, unet_feature):
+         attn_stack = self.attn_norm(attn_stack)
+         attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
+         unet_feature = torch.cat([unet_feature, attn_stack], dim=1)  # [1, 324, 64, 64]
+         unet_feature = unet_feature * attn_stack_mean
+         if unet_feature.shape[1] == 322:
+             # NOTE: this path requires self.fuse1, which is commented out in __init__
+             unet_feature = self.fuse1(unet_feature)
+         else:
+             unet_feature = self.fuse(unet_feature)
+
+         response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
+
+         outputs = []
+         for i in range(len(response_maps)):
+             response_map = response_maps[i] + unet_feature
+             if i == len(response_maps) - 1:
+                 predicted_dmaps = self.regression_head(response_map)
+             else:
+                 predicted_dmaps = self.aux_heads[i](response_map)
+             outputs.append(predicted_dmaps)
+
+         return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
+
+     def forward_reg1(self, response_maps, self_attn):
+         # attn_stack = self.attn_norm(attn_stack)
+         # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
+         # unet_feature = torch.cat([unet_feature, attn_stack], dim=1)  # [1, 324, 64, 64]
+         # unet_feature = unet_feature * attn_stack_mean
+         # if unet_feature.shape[1] == 322:
+         #     unet_feature = self.fuse1(unet_feature)
+         # else:
+         #     unet_feature = self.fuse(unet_feature)
+
+         response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
+
+         outputs = []
+         for i in range(len(response_maps)):
+             response_map = response_maps[i] + self_attn
+             if i == len(response_maps) - 1:
+                 predicted_dmaps = self.regression_head(response_map)
+             else:
+                 predicted_dmaps = self.aux_heads[i](response_map)
+             outputs.append(predicted_dmaps)
+
+         return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
+
+     def forward_reg_without_unet(self, response_maps, attn_stack):
+         # attn_stack = self.attn_norm(attn_stack)
+         attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
+
+         response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
+
+         outputs = []
+         for i in range(len(response_maps)):
+             response_map = response_maps[i] * attn_stack_mean * 0.5 + response_maps[i]
+             if i == len(response_maps) - 1:
+                 predicted_dmaps = self.regression_head(response_map)
+             else:
+                 predicted_dmaps = self.aux_heads[i](response_map)
+             outputs.append(predicted_dmaps)
+
+         return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
+
+
+ def build_model(args):
+
+     assert args.backbone in ['resnet18', 'resnet50', 'resnet101']
+     assert args.reduction in [4, 8, 16]
+
+     return LOCA(
+         image_size=args.image_size,
+         num_encoder_layers=args.num_enc_layers,
+         num_ope_iterative_steps=args.num_ope_iterative_steps,
+         num_objects=args.num_objects,
+         zero_shot=args.zero_shot,
+         emb_dim=args.emb_dim,
+         num_heads=args.num_heads,
+         kernel_dim=args.kernel_dim,
+         backbone_name=args.backbone,
+         swav_backbone=args.swav_backbone,
+         train_backbone=args.backbone_lr > 0,
+         reduction=args.reduction,
+         dropout=args.dropout,
+         layer_norm_eps=1e-5,
+         mlp_factor=8,
+         norm_first=args.pre_norm,
+         activation=nn.GELU,
+         norm=True,
+     )
models/enc_model/loca_args.py ADDED
@@ -0,0 +1,44 @@
+ import argparse
+
+
+ def get_argparser():
+
+     parser = argparse.ArgumentParser("LOCA parser", add_help=False)
+
+     parser.add_argument('--model_name', default='loca_few_shot', type=str)
+     parser.add_argument(
+         '--data_path',
+         default='./data/FSC147_384_V2',
+         type=str
+     )
+     parser.add_argument(
+         '--model_path',
+         default='ckpt',
+         type=str
+     )
+     parser.add_argument('--backbone', default='resnet50', type=str)
+     parser.add_argument('--swav_backbone', action='store_true', default=True)
+     parser.add_argument('--reduction', default=8, type=int)
+     parser.add_argument('--image_size', default=512, type=int)
+     parser.add_argument('--num_enc_layers', default=3, type=int)
+     parser.add_argument('--num_ope_iterative_steps', default=3, type=int)
+     parser.add_argument('--emb_dim', default=256, type=int)
+     parser.add_argument('--num_heads', default=8, type=int)
+     parser.add_argument('--kernel_dim', default=3, type=int)
+     parser.add_argument('--num_objects', default=3, type=int)
+     parser.add_argument('--epochs', default=200, type=int)
+     parser.add_argument('--resume_training', action='store_true')
+     parser.add_argument('--lr', default=1e-4, type=float)
+     parser.add_argument('--backbone_lr', default=0, type=float)
+     parser.add_argument('--lr_drop', default=200, type=int)
+     parser.add_argument('--weight_decay', default=1e-4, type=float)
+     parser.add_argument('--batch_size', default=1, type=int)
+     parser.add_argument('--dropout', default=0.1, type=float)
+     parser.add_argument('--num_workers', default=8, type=int)
+     parser.add_argument('--max_grad_norm', default=0.1, type=float)
+     parser.add_argument('--aux_weight', default=0.3, type=float)
+     parser.add_argument('--tiling_p', default=0.5, type=float)
+     parser.add_argument('--zero_shot', action='store_true')
+     parser.add_argument('--pre_norm', action='store_true', default=True)
+
+     return parser
models/enc_model/mlp.py ADDED
@@ -0,0 +1,23 @@
+ from torch import nn
+
+
+ class MLP(nn.Module):
+
+     def __init__(
+         self,
+         input_dim: int,
+         hidden_dim: int,
+         dropout: float,
+         activation: nn.Module
+     ):
+         super(MLP, self).__init__()
+
+         self.linear1 = nn.Linear(input_dim, hidden_dim)
+         self.linear2 = nn.Linear(hidden_dim, input_dim)
+         self.dropout = nn.Dropout(dropout)
+         self.activation = activation()
+
+     def forward(self, x):
+         return (
+             self.linear2(self.dropout(self.activation(self.linear1(x))))
+         )
models/enc_model/ope.py ADDED
@@ -0,0 +1,245 @@
+ from .mlp import MLP
+ from .positional_encoding import PositionalEncodingsFixed
+
+ import torch
+ from torch import nn
+
+ from torchvision.ops import roi_align
+
+
+ class OPEModule(nn.Module):
+
+     def __init__(
+         self,
+         num_iterative_steps: int,
+         emb_dim: int,
+         kernel_dim: int,
+         num_objects: int,
+         num_heads: int,
+         reduction: int,
+         layer_norm_eps: float,
+         mlp_factor: int,
+         norm_first: bool,
+         activation: nn.Module,
+         norm: bool,
+         zero_shot: bool,
+     ):
+
+         super(OPEModule, self).__init__()
+
+         self.num_iterative_steps = num_iterative_steps
+         self.zero_shot = zero_shot
+         self.kernel_dim = kernel_dim
+         self.num_objects = num_objects
+         self.emb_dim = emb_dim
+         self.reduction = reduction
+
+         if num_iterative_steps > 0:
+             self.iterative_adaptation = IterativeAdaptationModule(
+                 num_layers=num_iterative_steps, emb_dim=emb_dim, num_heads=num_heads,
+                 dropout=0, layer_norm_eps=layer_norm_eps,
+                 mlp_factor=mlp_factor, norm_first=norm_first,
+                 activation=activation, norm=norm,
+                 zero_shot=zero_shot
+             )
+
+         if not self.zero_shot:
+             self.shape_or_objectness = nn.Sequential(
+                 nn.Linear(2, 64),
+                 nn.ReLU(),
+                 nn.Linear(64, emb_dim),
+                 nn.ReLU(),
+                 nn.Linear(emb_dim, self.kernel_dim**2 * emb_dim)
+             )
+         else:
+             self.shape_or_objectness = nn.Parameter(
+                 torch.empty((self.num_objects, self.kernel_dim**2, emb_dim))
+             )
+             nn.init.normal_(self.shape_or_objectness)
+
+         self.pos_emb = PositionalEncodingsFixed(emb_dim)
+
+     def forward(self, f_e, pos_emb, bboxes):
+         bs, _, h, w = f_e.size()
+         # extract the shape features or objectness
+         if not self.zero_shot:
+             box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
+             box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
+             box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]
+             shape_or_objectness = self.shape_or_objectness(box_hw).reshape(
+                 bs, -1, self.kernel_dim ** 2, self.emb_dim
+             ).flatten(1, 2).transpose(0, 1)
+         else:
+             shape_or_objectness = self.shape_or_objectness.expand(
+                 bs, -1, -1, -1
+             ).flatten(1, 2).transpose(0, 1)
+
+         # if not zero shot add appearance
+         if not self.zero_shot:
+             # reshape bboxes into the format suitable for roi_align
+             num_of_boxes = bboxes.size(1)
+             bboxes = torch.cat([
+                 torch.arange(
+                     bs, requires_grad=False
+                 ).to(bboxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
+                 bboxes.flatten(0, 1),
+             ], dim=1)
+             appearance = roi_align(
+                 f_e,
+                 boxes=bboxes, output_size=self.kernel_dim,
+                 spatial_scale=1.0 / self.reduction, aligned=True
+             ).permute(0, 2, 3, 1).reshape(
+                 bs, num_of_boxes * self.kernel_dim ** 2, -1
+             ).transpose(0, 1)
+         else:
+             num_of_boxes = self.num_objects
+             appearance = None
+
+         query_pos_emb = self.pos_emb(
+             bs, self.kernel_dim, self.kernel_dim, f_e.device
+         ).flatten(2).permute(2, 0, 1).repeat(num_of_boxes, 1, 1)
+
+         if self.num_iterative_steps > 0:
+             memory = f_e.flatten(2).permute(2, 0, 1)
+             all_prototypes = self.iterative_adaptation(
+                 shape_or_objectness, appearance, memory, pos_emb, query_pos_emb
+             )
+         else:
+             if shape_or_objectness is not None and appearance is not None:
+                 all_prototypes = (shape_or_objectness + appearance).unsqueeze(0)
+             else:
+                 all_prototypes = (
+                     shape_or_objectness if shape_or_objectness is not None else appearance
+                 ).unsqueeze(0)
+
+         return all_prototypes
+
+
+ class IterativeAdaptationModule(nn.Module):
+
+     def __init__(
+         self,
+         num_layers: int,
+         emb_dim: int,
+         num_heads: int,
+         dropout: float,
+         layer_norm_eps: float,
+         mlp_factor: int,
+         norm_first: bool,
+         activation: nn.Module,
+         norm: bool,
+         zero_shot: bool
+     ):
+
+         super(IterativeAdaptationModule, self).__init__()
+
+         self.layers = nn.ModuleList([
+             IterativeAdaptationLayer(
+                 emb_dim, num_heads, dropout, layer_norm_eps,
+                 mlp_factor, norm_first, activation, zero_shot
+             ) for i in range(num_layers)
+         ])
+
+         self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()
+
+     def forward(
+         self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask=None, memory_mask=None,
+         tgt_key_padding_mask=None, memory_key_padding_mask=None
+     ):
+
+         output = tgt
+         outputs = list()
+         for i, layer in enumerate(self.layers):
+             output = layer(
+                 output, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
+                 tgt_key_padding_mask, memory_key_padding_mask
+             )
+             outputs.append(self.norm(output))
+
+         return torch.stack(outputs)
+
+
+ class IterativeAdaptationLayer(nn.Module):
+
+     def __init__(
+         self,
+         emb_dim: int,
+         num_heads: int,
+         dropout: float,
+         layer_norm_eps: float,
+         mlp_factor: int,
+         norm_first: bool,
+         activation: nn.Module,
+         zero_shot: bool
+     ):
+         super(IterativeAdaptationLayer, self).__init__()
+
+         self.norm_first = norm_first
+         self.zero_shot = zero_shot
+
+         if not self.zero_shot:
+             self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
+         self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
+         self.norm3 = nn.LayerNorm(emb_dim, layer_norm_eps)
+         if not self.zero_shot:
+             self.dropout1 = nn.Dropout(dropout)
+         self.dropout2 = nn.Dropout(dropout)
+         self.dropout3 = nn.Dropout(dropout)
+
+         if not self.zero_shot:
+             self.self_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
+         self.enc_dec_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
+
+         self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)
+
+     def with_emb(self, x, emb):
+         return x if emb is None else x + emb
+
+     def forward(
+         self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
+         tgt_key_padding_mask, memory_key_padding_mask
+     ):
+         if self.norm_first:
+             if not self.zero_shot:
+                 tgt_norm = self.norm1(tgt)
+                 tgt = tgt + self.dropout1(self.self_attn(
+                     query=self.with_emb(tgt_norm, query_pos_emb),
+                     key=self.with_emb(appearance, query_pos_emb),
+                     value=appearance,
+                     attn_mask=tgt_mask,
+                     key_padding_mask=tgt_key_padding_mask
+                 )[0])
+
+             tgt_norm = self.norm2(tgt)
+             tgt = tgt + self.dropout2(self.enc_dec_attn(
+                 query=self.with_emb(tgt_norm, query_pos_emb),
+                 key=memory + pos_emb,
+                 value=memory,
+                 attn_mask=memory_mask,
+                 key_padding_mask=memory_key_padding_mask
+             )[0])
+             tgt_norm = self.norm3(tgt)
+             tgt = tgt + self.dropout3(self.mlp(tgt_norm))
+
+         else:
+             if not self.zero_shot:
+                 tgt = self.norm1(tgt + self.dropout1(self.self_attn(
+                     query=self.with_emb(tgt, query_pos_emb),
+                     key=self.with_emb(appearance, query_pos_emb),  # with_emb takes both arguments
+                     value=appearance,
+                     attn_mask=tgt_mask,
+                     key_padding_mask=tgt_key_padding_mask
+                 )[0]))
+
+             tgt = self.norm2(tgt + self.dropout2(self.enc_dec_attn(
+                 query=self.with_emb(tgt, query_pos_emb),
+                 key=memory + pos_emb,
+                 value=memory,
+                 attn_mask=memory_mask,
+                 key_padding_mask=memory_key_padding_mask
+             )[0]))
+
+             tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
+
+         return tgt
models/enc_model/positional_encoding.py ADDED
@@ -0,0 +1,30 @@
+ import torch
+ from torch import nn
+
+
+ class PositionalEncodingsFixed(nn.Module):
+
+     def __init__(self, emb_dim, temperature=10000):
+
+         super(PositionalEncodingsFixed, self).__init__()
+
+         self.emb_dim = emb_dim
+         self.temperature = temperature
+
+     def _1d_pos_enc(self, mask, dim):
+         temp = torch.arange(self.emb_dim // 2).float().to(mask.device)
+         temp = self.temperature ** (2 * (temp.div(2, rounding_mode='floor')) / self.emb_dim)
+
+         enc = (~mask).cumsum(dim).float().unsqueeze(-1) / temp
+         enc = torch.stack([
+             enc[..., 0::2].sin(), enc[..., 1::2].cos()
+         ], dim=-1).flatten(-2)
+
+         return enc
+
+     def forward(self, bs, h, w, device):
+         mask = torch.zeros(bs, h, w, dtype=torch.bool, requires_grad=False, device=device)
+         x = self._1d_pos_enc(mask, dim=2)
+         y = self._1d_pos_enc(mask, dim=1)
+
+         return torch.cat([y, x], dim=3).permute(0, 3, 1, 2)
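A shape sketch (illustrative): the module builds fixed sine/cosine encodings per axis and concatenates them along the channel dimension, y-encodings in the first half and x-encodings in the second.

    import torch

    pe = PositionalEncodingsFixed(emb_dim=256)
    emb = pe(bs=2, h=64, w=64, device=torch.device("cpu"))
    print(emb.shape)   # torch.Size([2, 256, 64, 64])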
models/enc_model/regression_head.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+
4
+
5
+ class UpsamplingLayer(nn.Module):
6
+
7
+ def __init__(self, in_channels, out_channels, leaky=True):
8
+
9
+ super(UpsamplingLayer, self).__init__()
10
+
11
+ self.layer = nn.Sequential(
12
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
13
+ nn.LeakyReLU() if leaky else nn.ReLU(),
14
+ nn.UpsamplingBilinear2d(scale_factor=2)
15
+ )
16
+
17
+ def forward(self, x):
18
+ return self.layer(x)
19
+
20
+
21
+ class DensityMapRegressor(nn.Module):
22
+
23
+ def __init__(self, in_channels, reduction):
24
+
25
+ super(DensityMapRegressor, self).__init__()
26
+
27
+ if reduction == 8:
28
+ self.regressor = nn.Sequential(
29
+ UpsamplingLayer(in_channels, 128),
30
+ UpsamplingLayer(128, 64),
31
+ UpsamplingLayer(64, 32),
32
+ nn.Conv2d(32, 1, kernel_size=1),
33
+ nn.LeakyReLU()
34
+ )
35
+ elif reduction == 16:
36
+ self.regressor = nn.Sequential(
37
+ UpsamplingLayer(in_channels, 128),
38
+ UpsamplingLayer(128, 64),
39
+ UpsamplingLayer(64, 32),
40
+ UpsamplingLayer(32, 16),
41
+ nn.Conv2d(16, 1, kernel_size=1),
42
+ nn.LeakyReLU()
43
+ )
44
+
45
+ self.reset_parameters()
46
+
47
+ def forward(self, x):
48
+ return self.regressor(x)
49
+
50
+ def reset_parameters(self):
51
+ for module in self.modules():
52
+ if isinstance(module, nn.Conv2d):
53
+ nn.init.normal_(module.weight, std=0.01)
54
+ if module.bias is not None:
55
+ nn.init.constant_(module.bias, 0)
56
+
57
+
58
+ class DensityMapRegressor_(nn.Module):
59
+
60
+ def __init__(self, in_channels, reduction):
61
+
62
+ super(DensityMapRegressor, self).__init__()
63
+
64
+ if reduction == 8:
65
+ self.regressor = nn.Sequential(
66
+ UpsamplingLayer(in_channels, 128),
67
+ UpsamplingLayer(128, 64),
68
+ UpsamplingLayer(64, 32),
69
+ nn.Conv2d(32, 1, kernel_size=1),
70
+ nn.LeakyReLU()
71
+ )
72
+ elif reduction == 16:
73
+ self.regressor = nn.Sequential(
74
+ UpsamplingLayer(in_channels, 128),
75
+ UpsamplingLayer(128, 64),
76
+ UpsamplingLayer(64, 32),
77
+ UpsamplingLayer(32, 16),
78
+ nn.Conv2d(16, 1, kernel_size=1),
79
+ nn.LeakyReLU()
80
+ )
81
+
82
+ self.reset_parameters()
83
+
84
+ def forward(self, x):
85
+ return self.regressor(x)
86
+
87
+ def reset_parameters(self):
88
+ for module in self.modules():
89
+ if isinstance(module, nn.Conv2d):
90
+ nn.init.normal_(module.weight, std=0.01)
91
+ if module.bias is not None:
92
+ nn.init.constant_(module.bias, 0)
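A shape sketch (illustrative): reduction=8 means three x2 bilinear upsamplings, so a stride-8 response map is brought back to the input resolution (2**3 = 8).

    import torch

    head = DensityMapRegressor(in_channels=256, reduction=8)
    x = torch.rand(1, 256, 64, 64)   # stride-8 response map for a 512x512 input
    dmap = head(x)
    print(dmap.shape)                # torch.Size([1, 1, 512, 512])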
models/enc_model/transformer.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mlp import MLP
2
+
3
+ from torch import nn
4
+
5
+
6
+ class TransformerEncoder(nn.Module):
7
+
8
+ def __init__(
9
+ self,
10
+ num_layers: int,
11
+ emb_dim: int,
12
+ num_heads: int,
13
+ dropout: float,
14
+ layer_norm_eps: float,
15
+ mlp_factor: int,
16
+ norm_first: bool,
17
+ activation: nn.Module,
18
+ norm: bool,
19
+ ):
20
+
21
+ super(TransformerEncoder, self).__init__()
22
+
23
+ self.layers = nn.ModuleList([
24
+ TransformerEncoderLayer(
25
+ emb_dim, num_heads, dropout, layer_norm_eps,
26
+ mlp_factor, norm_first, activation
27
+ ) for _ in range(num_layers)
28
+ ])
29
+
30
+ self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()
31
+
32
+ def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
33
+ output = src
34
+ for layer in self.layers:
35
+ output = layer(output, pos_emb, src_mask, src_key_padding_mask)
36
+ return self.norm(output)
37
+
38
+
39
+ class TransformerEncoderLayer(nn.Module):
40
+
41
+ def __init__(
42
+ self,
43
+ emb_dim: int,
44
+ num_heads: int,
45
+ dropout: float,
46
+ layer_norm_eps: float,
47
+ mlp_factor: int,
48
+ norm_first: bool,
49
+ activation: nn.Module,
50
+ ):
51
+ super(TransformerEncoderLayer, self).__init__()
52
+
53
+ self.norm_first = norm_first
54
+
55
+ self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
56
+ self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
57
+ self.dropout1 = nn.Dropout(dropout)
58
+ self.dropout2 = nn.Dropout(dropout)
59
+
60
+ self.self_attn = nn.MultiheadAttention(
61
+ emb_dim, num_heads, dropout
62
+ )
63
+ self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)
64
+
65
+ def with_emb(self, x, emb):
66
+ return x if emb is None else x + emb
67
+
68
+ def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
69
+ if self.norm_first:
70
+ src_norm = self.norm1(src)
71
+ q = k = src_norm + pos_emb
72
+ src = src + self.dropout1(self.self_attn(
73
+ query=q,
74
+ key=k,
75
+ value=src_norm,
76
+ attn_mask=src_mask,
77
+ key_padding_mask=src_key_padding_mask
78
+ )[0])
79
+
80
+ src_norm = self.norm2(src)
81
+ src = src + self.dropout2(self.mlp(src_norm))
82
+ else:
83
+ q = k = src + pos_emb
84
+ src = self.norm1(src + self.dropout1(self.self_attn(
85
+ query=q,
86
+ key=k,
87
+ value=src,
88
+ attn_mask=src_mask,
89
+ key_padding_mask=src_key_padding_mask
90
+ )[0]))
91
+
92
+ src = self.norm2(src + self.dropout2(self.mlp(src)))
93
+
94
+ return src
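A minimal usage sketch for the encoder above. The tensor layout follows `nn.MultiheadAttention`'s default sequence-first convention `[L, N, E]`; the sizes are illustrative, and `activation` is assumed to be a module class (e.g. `nn.GELU`) that the `MLP` block instantiates, as in LOCA.

import torch
from torch import nn

enc = TransformerEncoder(
    num_layers=2, emb_dim=256, num_heads=8, dropout=0.1,
    layer_norm_eps=1e-5, mlp_factor=4, norm_first=True,
    activation=nn.GELU, norm=True,
)
src = torch.randn(64 * 64, 1, 256)   # [seq_len, batch, emb_dim]
pos = torch.randn(64 * 64, 1, 256)   # positional embedding, added to q/k only
out = enc(src, pos, src_mask=None, src_key_padding_mask=None)  # -> [4096, 1, 256]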
models/enc_model/unet_parts.py ADDED
@@ -0,0 +1,77 @@
1
+ """ Parts of the U-Net model """
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class DoubleConv(nn.Module):
9
+ """(convolution => [BN] => ReLU) * 2"""
10
+
11
+ def __init__(self, in_channels, out_channels, mid_channels=None):
12
+ super().__init__()
13
+ if not mid_channels:
14
+ mid_channels = out_channels
15
+ self.double_conv = nn.Sequential(
16
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
17
+ nn.BatchNorm2d(mid_channels),
18
+ nn.ReLU(inplace=True),
19
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
20
+ nn.BatchNorm2d(out_channels),
21
+ nn.ReLU(inplace=True)
22
+ )
23
+
24
+ def forward(self, x):
25
+ return self.double_conv(x)
26
+
27
+
28
+ class Down(nn.Module):
29
+ """Downscaling with maxpool then double conv"""
30
+
31
+ def __init__(self, in_channels, out_channels):
32
+ super().__init__()
33
+ self.maxpool_conv = nn.Sequential(
34
+ nn.MaxPool2d(2),
35
+ DoubleConv(in_channels, out_channels)
36
+ )
37
+
38
+ def forward(self, x):
39
+ return self.maxpool_conv(x)
40
+
41
+
42
+ class Up(nn.Module):
43
+ """Upscaling then double conv"""
44
+
45
+ def __init__(self, in_channels, out_channels, bilinear=True):
46
+ super().__init__()
47
+
48
+ # if bilinear, use the normal convolutions to reduce the number of channels
49
+ if bilinear:
50
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52
+ else:
53
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
54
+ self.conv = DoubleConv(in_channels, out_channels)
55
+
56
+ def forward(self, x1, x2):
57
+ x1 = self.up(x1)
58
+ # input is CHW
59
+ diffY = x2.size()[2] - x1.size()[2]
60
+ diffX = x2.size()[3] - x1.size()[3]
61
+
62
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
63
+ diffY // 2, diffY - diffY // 2])
64
+ # if you have padding issues, see
65
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
66
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
67
+ x = torch.cat([x2, x1], dim=1)
68
+ return self.conv(x)
69
+
70
+
71
+ class OutConv(nn.Module):
72
+ def __init__(self, in_channels, out_channels):
73
+ super(OutConv, self).__init__()
74
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
75
+
76
+ def forward(self, x):
77
+ return self.conv(x)
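These parts compose into the usual encoder/decoder; a minimal two-level sketch using the classes above (channel widths are illustrative):

import torch
import torch.nn as nn

class TinyUNet(nn.Module):
    def __init__(self, n_channels=3, n_classes=2):
        super().__init__()
        self.inc = DoubleConv(n_channels, 64)
        self.down1 = Down(64, 128)
        # bilinear=False: ConvTranspose2d halves channels before the skip concat
        self.up1 = Up(128, 64, bilinear=False)
        self.outc = OutConv(64, n_classes)

    def forward(self, x):
        x1 = self.inc(x)        # [B, 64, H, W]
        x2 = self.down1(x1)     # [B, 128, H/2, W/2]
        x = self.up1(x2, x1)    # upsample + concat skip -> [B, 64, H, W]
        return self.outc(x)     # [B, n_classes, H, W]

logits = TinyUNet()(torch.randn(1, 3, 64, 64))  # -> [1, 2, 64, 64]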
models/model.py ADDED
@@ -0,0 +1,991 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+ import clip
6
+ import sys
7
+ from models.seg_post_model.cellpose.models import CellposeModel
8
+
9
+ from torchvision.ops import roi_align
10
+ def crop_roi_feat(feat, boxes):
11
+ """
12
+ feat: 1 x c x h x w
13
+ boxes: m x 4, 4: [y_tl, x_tl, y_br, x_br]
14
+ """
15
+ _, _, h, w = feat.shape
16
+ out_stride = 512 / h
17
+ boxes_scaled = boxes / out_stride
18
+ boxes_scaled[:, :2] = torch.floor(boxes_scaled[:, :2]) # y_tl, x_tl: floor
19
+ boxes_scaled[:, 2:] = torch.ceil(boxes_scaled[:, 2:]) # y_br, x_br: ceil
20
+ boxes_scaled[:, :2] = torch.clamp_min(boxes_scaled[:, :2], 0)
21
+ boxes_scaled[:, 2] = torch.clamp_max(boxes_scaled[:, 2], h)
22
+ boxes_scaled[:, 3] = torch.clamp_max(boxes_scaled[:, 3], w)
23
+ feat_boxes = []
24
+ for idx_box in range(0, boxes.shape[0]):
25
+ y_tl, x_tl, y_br, x_br = boxes_scaled[idx_box]
26
+ y_tl, x_tl, y_br, x_br = int(y_tl), int(x_tl), int(y_br), int(x_br)
27
+ feat_box = feat[:, :, y_tl : (y_br + 1), x_tl : (x_br + 1)]
28
+ feat_boxes.append(feat_box)
29
+ return feat_boxes
30
+
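`crop_roi_feat` assumes the feature map comes from a 512x512 input (note the hard-coded `out_stride = 512 / h`) and returns variable-sized crops rather than pooled ones. A small usage sketch with made-up boxes:

import torch

feat = torch.randn(1, 256, 64, 64)                # stride-8 feature of a 512x512 image
boxes = torch.tensor([[ 40.,  60., 120., 150.],   # [y_tl, x_tl, y_br, x_br] in image pixels
                      [200., 220., 260., 300.]])
crops = crop_roi_feat(feat, boxes)                # list of [1, 256, h_i, w_i] tensors
for c in crops:
    print(c.shape)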
31
+ class Counting_with_SD_features(nn.Module):
32
+ def __init__(self, scale_factor):
33
+ super(Counting_with_SD_features, self).__init__()
34
+ self.adapter = adapter_roi()
35
+ # self.regressor = regressor_with_SD_features()
36
+
37
+ class Counting_with_SD_features_loca(nn.Module):
38
+ def __init__(self, scale_factor):
39
+ super(Counting_with_SD_features_loca, self).__init__()
40
+ self.adapter = adapter_roi_loca()
41
+ self.regressor = regressor_with_SD_features()
42
+
43
+
44
+ class Counting_with_SD_features_dino_vit_c3(nn.Module):
45
+ def __init__(self, scale_factor, vit=None):
46
+ super(Counting_with_SD_features_dino_vit_c3, self).__init__()
47
+ self.adapter = adapter_roi_loca()
48
+ self.regressor = regressor_with_SD_features_seg_vit_c3()
49
+
50
+ class Counting_with_SD_features_track(nn.Module):
51
+ def __init__(self, scale_factor, vit=None):
52
+ super(Counting_with_SD_features_track, self).__init__()
53
+ self.adapter = adapter_roi_loca()
54
+ self.regressor = regressor_with_SD_features_tra()
55
+
56
+ class Counting_with_SD_features_loca_rand(nn.Module):
57
+ def __init__(self, scale_factor, num_of_roi = 3):
58
+ super(Counting_with_SD_features_loca_rand, self).__init__()
59
+ self.adapter = adapter_roi_loca_rand(num_of_roi=num_of_roi)
60
+ self.regressor = regressor_with_SD_features()
61
+
62
+ class Counting_with_SD_features_loca_carpk(nn.Module):
63
+ def __init__(self, scale_factor, num_of_roi = 3):
64
+ super(Counting_with_SD_features_loca_carpk, self).__init__()
65
+ self.adapter = adapter_roi_loca_carpk(num_of_roi=num_of_roi)
66
+ self.regressor = regressor_with_SD_features()
67
+
68
+ class Counting_with_SD_features_clip_carpk(nn.Module):
69
+ def __init__(self, scale_factor, num_of_roi = 3):
70
+ super(Counting_with_SD_features_clip_carpk, self).__init__()
71
+ self.adapter = adapter_roi_clip_carpk(num_of_roi=num_of_roi)
72
+ # self.regressor = regressor_with_SD_features()
73
+
74
+ class Counting_with_SD_features_zero(nn.Module):
75
+ def __init__(self, scale_factor):
76
+ super(Counting_with_SD_features_zero, self).__init__()
77
+ self.adapter = adapter_roi_zero()
78
+ self.regressor = regressor_with_SD_features()
79
+
80
+ class Counting_with_SD_features_zero_loca(nn.Module):
81
+ def __init__(self, scale_factor):
82
+ super(Counting_with_SD_features_zero_loca, self).__init__()
83
+ self.adapter = adapter_roi_zero_loca()
84
+ self.regressor = regressor_with_SD_features()
85
+
86
+ class Counting_with_SD_features_zero_loca_self(nn.Module):
87
+ def __init__(self, scale_factor):
88
+ super(Counting_with_SD_features_zero_loca_self, self).__init__()
89
+ self.adapter = adapter_roi_zero_loca()
90
+ # self.regressor = regressor_with_SD_features_self()
91
+ self.regressor = regressor_with_SD_features_latent()
92
+
93
+ class Counting_with_SD_features_loca_v2(nn.Module):
94
+ def __init__(self, scale_factor):
95
+ super(Counting_with_SD_features_loca_v2, self).__init__()
96
+ self.adapter = adapter_roi_loca_v2()
97
+ # self.regressor = regressor_with_SD_features()
98
+
99
+ class adapter1(nn.Module):
100
+ def __init__(self):
101
+ super(adapter1, self).__init__()
102
+ self.conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
103
+ self.pool = nn.MaxPool2d(2)
104
+ self.fc = nn.Linear(128 * 64 * 64, 768)
105
+ self.initialize_weights()
106
+
107
+ def forward(self, x):
108
+ x = self.conv1(x)
109
+ x = self.pool(x)
110
+ x = x.view(x.size(0), -1)
111
+ x = self.fc(x)
112
+ return x
113
+
114
+ def initialize_weights(self):
115
+ for m in self.modules():
116
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
117
+ nn.init.xavier_normal_(m.weight)
118
+ if m.bias is not None:
119
+ nn.init.constant_(m.bias, 0)
120
+
121
+ class adapter(nn.Module):
122
+ def __init__(self, pool_size=[3, 3]):
123
+ super(adapter, self).__init__()
124
+ self.pool_size = pool_size
125
+ self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
126
+ self.pool = nn.MaxPool2d(2)
127
+ self.fc = nn.Linear(256 * 3 * 3, 768)
128
+ self.initialize_weights()
129
+
130
+ def forward(self, xs):
131
+ x_list = []
132
+ for x in xs:
133
+ x = F.adaptive_max_pool2d(x, self.pool_size, return_indices=False) # [1, 256, 3, 3]
134
+ x_list.append(x)
135
+ x_list = torch.cat(x_list, dim=0)
136
+ x_list = torch.mean(x_list, dim=0, keepdim=True) # [1, 256, 3, 3]
137
+ x = self.conv1(x_list)
138
+ # x = self.pool(x)
139
+ x = x.view(x.size(0), -1)
140
+ x = self.fc(x)
141
+ return x
142
+
143
+ def initialize_weights(self):
144
+ for m in self.modules():
145
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
146
+ nn.init.xavier_normal_(m.weight)
147
+ if m.bias is not None:
148
+ nn.init.constant_(m.bias, 0)
149
+
150
+ class adapter_roi(nn.Module):
151
+ def __init__(self, pool_size=[3, 3]):
152
+ super(adapter_roi, self).__init__()
153
+ self.pool_size = pool_size
154
+ self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
155
+ # self.relu = nn.ReLU()
156
+ # self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
157
+ self.pool = nn.MaxPool2d(2)
158
+ self.fc = nn.Linear(256 * 3 * 3, 768)
159
+ # **new
160
+ self.fc1 = nn.Sequential(
161
+ nn.ReLU(),
162
+ nn.Linear(768, 768 // 4, bias=False),
163
+ nn.ReLU()
164
+ )
165
+ self.fc2 = nn.Sequential(
166
+ nn.Linear(768 // 4, 768, bias=False),
167
+ # nn.ReLU()
168
+ )
169
+ self.initialize_weights()
170
+
171
+ def forward(self, x, boxes):
172
+ num_of_boxes = boxes.shape[1]
173
+ rois = []
174
+ bs, _, h, w = x.shape
175
+ boxes = torch.cat([
176
+ torch.arange(
177
+ bs, requires_grad=False
178
+ ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
179
+ boxes.flatten(0, 1),
180
+ ], dim=1)
181
+ rois = roi_align(
182
+ x,
183
+ boxes=boxes, output_size=3,
184
+ spatial_scale=1.0 / 8, aligned=True
185
+ )
186
+ rois = torch.mean(rois, dim=0, keepdim=True)
187
+ x = self.conv1(rois)
188
+ x = x.view(x.size(0), -1)
189
+ x = self.fc(x)
190
+
191
+ x = self.fc1(x)
192
+ x = self.fc2(x)
193
+ return x
194
+
195
+
196
+ def initialize_weights(self):
197
+ for m in self.modules():
198
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
199
+ nn.init.xavier_normal_(m.weight)
200
+ if m.bias is not None:
201
+ nn.init.constant_(m.bias, 0)
202
+
203
+
204
+ class adapter_roi_loca(nn.Module):
205
+ def __init__(self, pool_size=[3, 3]):
206
+ super(adapter_roi_loca, self).__init__()
207
+ self.pool_size = pool_size
208
+ self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
209
+ self.pool = nn.MaxPool2d(2)
210
+ self.fc = nn.Linear(256 * 3 * 3, 768)
211
+ self.initialize_weights()
212
+ def forward(self, x, boxes):
213
+ num_of_boxes = boxes.shape[1]
214
+ rois = []
215
+ bs, _, h, w = x.shape
216
+ if h != 512 or w != 512:
217
+ x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)
218
+ if bs == 1:
219
+ boxes = torch.cat([
220
+ torch.arange(
221
+ bs, requires_grad=False
222
+ ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
223
+ boxes.flatten(0, 1),
224
+ ], dim=1)
225
+ rois = roi_align(
226
+ x,
227
+ boxes=boxes, output_size=3,
228
+ spatial_scale=1.0 / 8, aligned=True
229
+ )
230
+ rois = torch.mean(rois, dim=0, keepdim=True)
231
+ else:
232
+ boxes = torch.cat([
233
+ boxes.flatten(0, 1),
234
+ ], dim=1).split(num_of_boxes, dim=0)
235
+ rois = roi_align(
236
+ x,
237
+ boxes=boxes, output_size=3,
238
+ spatial_scale=1.0 / 8, aligned=True
239
+ )
240
+ rois = rois.split(num_of_boxes, dim=0)
241
+ rois = torch.stack(rois, dim=0)
242
+ rois = torch.mean(rois, dim=1, keepdim=False)
243
+ x = self.conv1(rois)
244
+ x = x.view(x.size(0), -1)
245
+ x = self.fc(x)
246
+ return x
247
+
248
+ def forward_boxes(self, x, boxes):
249
+ num_of_boxes = boxes.shape[1]
250
+ rois = []
251
+ bs, _, h, w = x.shape
252
+ if h != 512 or w != 512:
253
+ x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)
254
+ if bs == 1:
255
+ boxes = torch.cat([
256
+ torch.arange(
257
+ bs, requires_grad=False
258
+ ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
259
+ boxes.flatten(0, 1),
260
+ ], dim=1)
261
+ rois = roi_align(
262
+ x,
263
+ boxes=boxes, output_size=3,
264
+ spatial_scale=1.0 / 8, aligned=True
265
+ )
266
+ # rois = torch.mean(rois, dim=0, keepdim=True)
267
+ else:
268
+ raise NotImplementedError
269
+ x = self.conv1(rois)
270
+ x = x.view(x.size(0), -1)
271
+ x = self.fc(x)
272
+ return x
273
+
274
+ def initialize_weights(self):
275
+ for m in self.modules():
276
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
277
+ nn.init.xavier_normal_(m.weight)
278
+ if m.bias is not None:
279
+ nn.init.constant_(m.bias, 0)
280
+
281
+
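The `torch.cat([batch_index_column, boxes.flatten(0, 1)], dim=1)` pattern used by these adapters builds the `[K, 5]` box tensor that torchvision's `roi_align` expects, whose first column is the batch index (torchvision's box order is `(x1, y1, x2, y2)`). A standalone sketch with illustrative boxes:

import torch
from torchvision.ops import roi_align

feat = torch.randn(2, 256, 64, 64)       # bs=2 stride-8 features of 512x512 inputs
boxes = torch.tensor([[[ 10.,  10.,  90.,  90.],
                       [100., 120., 180., 200.],
                       [300., 260., 380., 340.]]]).repeat(2, 1, 1)  # [bs, 3, 4]
idx = torch.arange(2).repeat_interleave(3).reshape(-1, 1).float()
rois = roi_align(feat, torch.cat([idx, boxes.flatten(0, 1)], dim=1),
                 output_size=3, spatial_scale=1.0 / 8, aligned=True)
print(rois.shape)                        # [6, 256, 3, 3]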
282
+ class adapter_roi_dino(nn.Module):
283
+ def __init__(self, pool_size=[3, 3]):
284
+ super(adapter_roi_dino, self).__init__()
285
+ self.pool_size = pool_size
286
+ # self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
287
+ # self.pool = nn.MaxPool2d(2)
288
+ self.fc = nn.Linear(1024, 768)
289
+ self.initialize_weights()
290
+ def forward(self, crops, dino_model):
291
+ num_of_boxes = len(crops)
292
+ feats = []
293
+ for i in range(num_of_boxes):
294
+ with torch.no_grad():
295
+ feat = dino_model(crops[i])
296
+
297
+ feats.append(feat)
298
+ feats = torch.cat(feats, dim=0)
299
+ feats = torch.mean(feats, dim=0)
300
+ x = self.fc(feats)
301
+ return x
302
+ def initialize_weights(self):
303
+ for m in self.modules():
304
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
305
+ nn.init.xavier_normal_(m.weight)
306
+ if m.bias is not None:
307
+ nn.init.constant_(m.bias, 0)
308
+
309
+
310
+
311
+ class adapter_roi_loca_v2(nn.Module):
312
+ def __init__(self, pool_size=[3, 3]):
313
+ super(adapter_roi_loca_v2, self).__init__()
314
+ self.pool_size = pool_size
315
+ self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
316
+ self.pool = nn.MaxPool2d(2)
317
+ self.fc = nn.Linear(256 * 3 * 3, 1024)
318
+ self.initialize_weights()
319
+ def forward(self, x, boxes):
320
+ rois = []
321
+ bs, _, h, w = x.shape
322
+ boxes = torch.cat([
323
+ torch.arange(
324
+ bs, requires_grad=False
325
+ ).to(boxes.device).repeat_interleave(3).reshape(-1, 1),
326
+ boxes.flatten(0, 1),
327
+ ], dim=1)
328
+ rois = roi_align(
329
+ x,
330
+ boxes=boxes, output_size=3,
331
+ spatial_scale=1.0 / 8, aligned=True
332
+ )
333
+ rois = torch.mean(rois, dim=0, keepdim=True)
334
+ x = self.conv1(rois)
335
+ x = x.view(x.size(0), -1)
336
+ x = self.fc(x)
337
+ return x
338
+ def initialize_weights(self):
339
+ for m in self.modules():
340
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
341
+ nn.init.xavier_normal_(m.weight)
342
+ if m.bias is not None:
343
+ nn.init.constant_(m.bias, 0)
344
+
345
+ class adapter_roi_zero(nn.Module):
346
+ def __init__(self, reduction=4):
347
+ super(adapter_roi_zero, self).__init__()
348
+ self.fc1 = nn.Sequential(
349
+ nn.Linear(768, 768 // reduction, bias=False),
350
+ nn.ReLU()
351
+ )
352
+ self.fc2 = nn.Sequential(
353
+ nn.Linear(768 // reduction, 768, bias=False),
354
+ nn.ReLU()
355
+ )
356
+ self.initialize_weights()
357
+ def forward(self, x):
358
+ x1 = self.fc1(x)
359
+ x1 = self.fc2(x1)
360
+ return x + x1
361
+ def initialize_weights(self):
362
+ for m in self.modules():
363
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
364
+ nn.init.xavier_normal_(m.weight)
365
+ if m.bias is not None:
366
+ nn.init.constant_(m.bias, 0)
367
+
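`adapter_roi_zero` above is a plain bottleneck residual adapter: y = x + ReLU(W_up · ReLU(W_down · x)), with W_down: 768 -> 768/reduction and W_up: 768/reduction -> 768, so it learns a low-rank residual correction to the 768-d class embedding rather than replacing it.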
368
+ class adapter_roi_zero_loca(nn.Module):
369
+ def __init__(self, reduction=4):
370
+ super(adapter_roi_zero_loca, self).__init__()
371
+ self.fc1 = nn.Sequential(
372
+ nn.Linear(768, 768 // reduction, bias=False),
373
+ nn.ReLU()
374
+ )
375
+ self.fc2 = nn.Sequential(
376
+ nn.Linear(768 // reduction, 768, bias=False),
377
+ nn.ReLU()
378
+ )
379
+
380
+ self.pool_size = (3, 3)
381
+ self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
382
+ self.pool = nn.MaxPool2d(2)
383
+ self.fc = nn.Linear(256 * 3 * 3, 768)
384
+
385
+ self.initialize_weights()
386
+ def forward(self, feature, boxes, class_emb):
387
+ x1 = self.fc1(class_emb)
388
+ x1 = self.fc2(x1)
389
+ class_emb = class_emb + x1
390
+
391
+ rois = []
392
+ bs, _, h, w = feature.shape
393
+ n_box = boxes.shape[1]
394
+ boxes = torch.cat([
395
+ torch.arange(
396
+ bs, requires_grad=False
397
+ ).to(boxes.device).repeat_interleave(n_box).reshape(-1, 1),
398
+ boxes.flatten(0, 1),
399
+ ], dim=1)
400
+ rois = roi_align(
401
+ feature,
402
+ boxes=boxes, output_size=3,
403
+ spatial_scale=1.0 / 8, aligned=True
404
+ )
405
+ # rois = torch.mean(rois, dim=0, keepdim=True)
406
+ x = self.conv1(rois)
407
+ x = x.view(x.size(0), -1)
408
+ x = self.fc(x)
409
+
410
+ if len(class_emb.shape) == 3:
411
+ class_emb = class_emb.squeeze(1)
412
+ dist = torch.cosine_similarity(class_emb, x) # [n_box]
413
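+ # note: torch.sort is ascending, so topk[:3] below selects the *least* similar of the first 10 boxes; compare adapter_roi_loca_rand, which sorts -dist to pick the most similar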
+ _, topk = torch.sort(dist[:10])
414
+ x_topk = x[topk[:3], :]
415
+ x_topk = torch.mean(x_topk, dim=0, keepdim=True)
416
+ return x_topk + class_emb
417
+
418
+ def vis(self, feature, boxes, class_emb):
419
+ x1 = self.fc1(class_emb)
420
+ x1 = self.fc2(x1)
421
+ class_emb = class_emb + x1
422
+
423
+ rois = []
424
+ bs, _, h, w = feature.shape
425
+ n_box = boxes.shape[1]
426
+ boxes = torch.cat([
427
+ torch.arange(
428
+ bs, requires_grad=False
429
+ ).to(boxes.device).repeat_interleave(n_box).reshape(-1, 1),
430
+ boxes.flatten(0, 1),
431
+ ], dim=1)
432
+ rois = roi_align(
433
+ feature,
434
+ boxes=boxes, output_size=3,
435
+ spatial_scale=1.0 / 8, aligned=True
436
+ )
437
+ # rois = torch.mean(rois, dim=0, keepdim=True)
438
+ x = self.conv1(rois)
439
+ x = x.view(x.size(0), -1)
440
+ x = self.fc(x)
441
+
442
+ if len(class_emb.shape) == 3:
443
+ class_emb = class_emb.squeeze(1)
444
+ dist = torch.cosine_similarity(class_emb, x) # [n_box]
445
+ _, topk = torch.sort(dist[:10])
446
+ x_topk = x[topk[:3], :]
447
+ x_topk = torch.mean(x_topk, dim=0, keepdim=True)
448
+ return x_topk
449
+
450
+ def initialize_weights(self):
451
+ for m in self.modules():
452
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
453
+ nn.init.xavier_normal_(m.weight)
454
+ if m.bias is not None:
455
+ nn.init.constant_(m.bias, 0)
456
+
457
+ class adapter_roi_loca_rand(nn.Module):
458
+ def __init__(self, pool_size=[3, 3],num_of_roi = 3):
459
+ super(adapter_roi_loca_rand, self).__init__()
460
+ self.pool_size = pool_size
461
+ self.num_of_roi = num_of_roi
462
+ self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
463
+ self.pool = nn.MaxPool2d(2)
464
+ self.fc = nn.Linear(256 * 3 * 3, 768)
465
+
466
+ # # **new
467
+ # self.fc1 = nn.Sequential(
468
+ # nn.Linear(768, 768 // 4, bias=False),
469
+ # nn.ReLU()
470
+ # )
471
+ # self.fc2 = nn.Sequential(
472
+ # nn.Linear(768 // 4, 768, bias=False),
473
+ # nn.ReLU()
474
+ # )
475
+ # #
476
+ self.initialize_weights()
477
+ def forward(self, x, boxes, rand_boxes):
478
+ num_of_boxes = boxes.shape[1]
479
+ bs, _, h, w = x.shape
480
+ boxes = torch.cat([
481
+ torch.arange(
482
+ bs, requires_grad=False
483
+ ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
484
+ boxes.flatten(0, 1),
485
+ ], dim=1)
486
+ rois = roi_align(
487
+ x,
488
+ boxes=boxes, output_size=3,
489
+ spatial_scale=1.0 / 8, aligned=True
490
+ )
491
+
492
+ # new
493
+ num_of_boxes = rand_boxes.shape[1]
494
+ bs, _, h, w = x.shape
495
+ rand_boxes = torch.cat([
496
+ torch.arange(
497
+ bs, requires_grad=False
498
+ ).to(rand_boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
499
+ rand_boxes.flatten(0, 1),
500
+ ], dim=1)
501
+ rand_rois = roi_align(
502
+ x,
503
+ boxes=rand_boxes, output_size=3,
504
+ spatial_scale=1.0 / 8, aligned=True
505
+ )
506
+
507
+ rois = torch.mean(rois, dim=0, keepdim=True)
508
+
509
+ # new
510
+ cos = torch.nn.CosineSimilarity(dim=1)
511
+ dist = cos(rois.view(1, -1), rand_rois.view(num_of_boxes, -1)) # [n_box]
512
+ _, topk = torch.sort(-dist)
513
+ x_topk = rand_rois[topk[:3], ...]
514
+ x_topk = torch.mean(x_topk, dim=0, keepdim=True)
515
+
516
+ rois += x_topk
517
+
518
+ x = self.conv1(rois)
519
+ x = x.view(x.size(0), -1)
520
+ x = self.fc(x)
521
+ # new
522
+ # x = self.fc1(x)
523
+ # x = self.fc2(x)
524
+ return x
525
+
526
+ def initialize_weights(self):
527
+ for m in self.modules():
528
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
529
+ nn.init.xavier_normal_(m.weight)
530
+ if m.bias is not None:
531
+ nn.init.constant_(m.bias, 0)
532
+
533
+
534
+ class regressor1(nn.Module):
535
+ def __init__(self):
536
+ super(regressor1, self).__init__()
537
+ self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
538
+ self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
539
+ self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
540
+ self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
541
+ self.leaky_relu = nn.LeakyReLU()
542
+ self.relu = nn.ReLU()
543
+ self.initialize_weights()
544
+
545
+
546
+
547
+ def forward(self, x):
548
+ x_ = self.conv1(x)
549
+ x_ = self.leaky_relu(x_)
550
+ x_ = self.upsampler(x_)
551
+ x_ = self.conv2(x_)
552
+ x_ = self.leaky_relu(x_)
553
+ x_ = self.upsampler(x_)
554
+ x_ = self.conv3(x_)
555
+ x_ = self.relu(x_)
556
+ out = x_
557
+ return out
558
+
559
+ def initialize_weights(self):
560
+ for m in self.modules():
561
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
562
+ nn.init.xavier_normal_(m.weight)
563
+ if m.bias is not None:
564
+ nn.init.constant_(m.bias, 0)
565
+
566
+
567
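+ # NOTE: this redefinition shadows the identical regressor1 above and, unlike it, never calls initialize_weights() in __init__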
+ class regressor1(nn.Module):
568
+ def __init__(self):
569
+ super(regressor1, self).__init__()
570
+ self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
571
+ self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
572
+ self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
573
+ self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
574
+ self.leaky_relu = nn.LeakyReLU()
575
+ self.relu = nn.ReLU()
576
+
577
+ def forward(self, x):
578
+ x_ = self.conv1(x)
579
+ x_ = self.leaky_relu(x_)
580
+ x_ = self.upsampler(x_)
581
+ x_ = self.conv2(x_)
582
+ x_ = self.leaky_relu(x_)
583
+ x_ = self.upsampler(x_)
584
+ x_ = self.conv3(x_)
585
+ x_ = self.relu(x_)
586
+ out = x_
587
+ return out
588
+ def initialize_weights(self):
589
+ for m in self.modules():
590
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
591
+ nn.init.xavier_normal_(m.weight)
592
+ if m.bias is not None:
593
+ nn.init.constant_(m.bias, 0)
594
+
595
+
596
+ class regressor_with_SD_features(nn.Module):
597
+ def __init__(self):
598
+ super(regressor_with_SD_features, self).__init__()
599
+ self.layer1 = nn.Sequential(
600
+ nn.Conv2d(324, 256, kernel_size=1, stride=1),
601
+ nn.LeakyReLU(),
602
+ nn.LayerNorm((64, 64))
603
+ )
604
+ self.layer2 = nn.Sequential(
605
+ nn.Conv2d(256, 128, kernel_size=3, padding=1),
606
+ nn.LeakyReLU(),
607
+ nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
608
+ )
609
+ self.layer3 = nn.Sequential(
610
+ nn.Conv2d(128, 64, kernel_size=3, padding=1),
611
+ nn.ReLU(),
612
+ nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
613
+ )
614
+ self.layer4 = nn.Sequential(
615
+ nn.Conv2d(64, 32, kernel_size=3, padding=1),
616
+ nn.LeakyReLU(),
617
+ nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
618
+ )
619
+ self.conv = nn.Sequential(
620
+ nn.Conv2d(32, 1, kernel_size=1),
621
+ nn.ReLU()
622
+ )
623
+ self.norm = nn.LayerNorm(normalized_shape=(64, 64))
624
+ self.initialize_weights()
625
+
626
+ def forward(self, attn_stack, feature_list):
627
+ attn_stack = self.norm(attn_stack)
628
+ unet_feature = feature_list[-1]
629
+ attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
630
+ unet_feature = unet_feature * attn_stack_mean
631
+ unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
632
+ x = self.layer1(unet_feature)
633
+ x = self.layer2(x)
634
+ x = self.layer3(x)
635
+ x = self.layer4(x)
636
+ out = self.conv(x)
637
+ return out / 100
638
+
639
+ def initialize_weights(self):
640
+ for m in self.modules():
641
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
642
+ nn.init.xavier_normal_(m.weight)
643
+ if m.bias is not None:
644
+ nn.init.constant_(m.bias, 0)
645
+
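A shape check for the decoder above, as a sketch; the 68 attention channels are an assumption chosen so that 256 UNet channels plus 68 attention maps give the 324 input channels of layer1:

import torch

reg = regressor_with_SD_features()
attn = torch.randn(1, 68, 64, 64)       # assumed stack of cross-attention maps
feats = [torch.randn(1, 256, 64, 64)]   # forward uses feature_list[-1]
out = reg(attn, feats)                  # three stride-2 deconvs: 64 -> 512
print(out.shape)                        # [1, 1, 512, 512], scaled by 1/100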
646
+ class regressor_with_SD_features_seg(nn.Module):
647
+ def __init__(self):
648
+ super(regressor_with_SD_features_seg, self).__init__()
649
+ self.layer1 = nn.Sequential(
650
+ nn.Conv2d(324, 256, kernel_size=1, stride=1),
651
+ nn.LeakyReLU(),
652
+ nn.LayerNorm((64, 64))
653
+ )
654
+ self.layer2 = nn.Sequential(
655
+ nn.Conv2d(256, 128, kernel_size=3, padding=1),
656
+ nn.LeakyReLU(),
657
+ nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
658
+ )
659
+ self.layer3 = nn.Sequential(
660
+ nn.Conv2d(128, 64, kernel_size=3, padding=1),
661
+ nn.ReLU(),
662
+ nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
663
+ )
664
+ self.layer4 = nn.Sequential(
665
+ nn.Conv2d(64, 32, kernel_size=3, padding=1),
666
+ nn.LeakyReLU(),
667
+ nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
668
+ )
669
+ self.conv = nn.Sequential(
670
+ nn.Conv2d(32, 2, kernel_size=1),
671
+ # nn.ReLU()
672
+ )
673
+ self.norm = nn.LayerNorm(normalized_shape=(64, 64))
674
+ self.initialize_weights()
675
+
676
+ def forward(self, attn_stack, feature_list):
677
+ attn_stack = self.norm(attn_stack)
678
+ unet_feature = feature_list[-1]
679
+ attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
680
+ unet_feature = unet_feature * attn_stack_mean
681
+ unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
682
+ x = self.layer1(unet_feature)
683
+ x = self.layer2(x)
684
+ x = self.layer3(x)
685
+ x = self.layer4(x)
686
+ out = self.conv(x)
687
+ return out
688
+
689
+ def initialize_weights(self):
690
+ for m in self.modules():
691
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
692
+ nn.init.xavier_normal_(m.weight)
693
+ if m.bias is not None:
694
+ nn.init.constant_(m.bias, 0)
695
+
696
+
697
+ from models.enc_model.unet_parts import *
698
+
699
+
700
+ class regressor_with_SD_features_seg_vit_c3(nn.Module):
701
+ def __init__(self, n_channels=3, n_classes=2, bilinear=False):
702
+ super(regressor_with_SD_features_seg_vit_c3, self).__init__()
703
+ self.n_channels = n_channels
704
+ self.n_classes = n_classes
705
+ self.bilinear = bilinear
706
+ self.norm = nn.LayerNorm(normalized_shape=(64, 64))
707
+ self.inc_0 = nn.Conv2d(n_channels, 3, kernel_size=3, padding=1)
708
+ self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
709
+ self.vit = self.vit_model.net
710
+
711
+ def forward(self, img, attn_stack, feature_list):
712
+ attn_stack = attn_stack[:, [1,3], ...]
713
+ attn_stack = self.norm(attn_stack)
714
+ unet_feature = feature_list[-1]
715
+ unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
716
+
717
+ x = torch.cat([unet_feature_mean, attn_stack], dim=1) # [1, 3, 64, 64]: 1 mean-feature channel + 2 attention channels
718
+
719
+ if x.shape[-1] != 512:
720
+ x = F.interpolate(x, size=(512, 512), mode="bilinear")
721
+ x = self.inc_0(x)
722
+
723
+
724
+
725
+ out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
726
+ out = torch.from_numpy(out).unsqueeze(0).to(x.device)
727
+ return out
728
+
729
+ def initialize_weights(self):
730
+ for m in self.modules():
731
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
732
+ nn.init.xavier_normal_(m.weight)
733
+ if m.bias is not None:
734
+ nn.init.constant_(m.bias, 0)
735
+
736
+ class regressor_with_SD_features_tra(nn.Module):
737
+ def __init__(self, n_channels=2, n_classes=2, bilinear=False):
738
+ super(regressor_with_SD_features_tra, self).__init__()
739
+ self.n_channels = n_channels
740
+ self.n_classes = n_classes
741
+ self.bilinear = bilinear
742
+ self.norm = nn.LayerNorm(normalized_shape=(64, 64))
743
+
744
+ # segmentation
745
+ self.inc_0 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
746
+ self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
747
+ self.vit = self.vit_model.net
748
+
749
+ self.inc_1 = nn.Conv2d(n_channels, 1, kernel_size=3, padding=1)
750
+ self.mlp = nn.Linear(64 * 64, 320)
751
+ # self.vit = self.vit_model.net.float()
752
+
753
+ def forward_seg(self, img, attn_stack, feature_list, mask, training=False):
754
+ attn_stack = attn_stack[:, [1,3], ...]
755
+ attn_stack = self.norm(attn_stack)
756
+ unet_feature = feature_list[-1]
757
+ unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
758
+ x = torch.cat([unet_feature_mean, attn_stack], dim=1) # [1, 3, 64, 64]: 1 mean-feature channel + 2 attention channels
759
+
760
+ if x.shape[-1] != 512:
761
+ x = F.interpolate(x, size=(512, 512), mode="bilinear")
762
+ x = self.inc_0(x)
763
+ feat = x
764
+
765
+ out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
766
+ out = torch.from_numpy(out).unsqueeze(0).to(x.device)
767
+ return out, 0., feat
768
+
769
+ def forward(self, attn_prev, feature_list_prev, attn_after, feature_list_after):
770
+ assert attn_prev.shape == attn_after.shape, "attn_prev and attn_after must have the same shape"
771
+ n_instances = attn_prev.shape[0]
772
+ attn_prev = self.norm(attn_prev) # [n_instances, 1, 64, 64]
773
+ attn_after = self.norm(attn_after)
774
+
775
+ x = torch.cat([attn_prev, attn_after], dim=1) # n_instances, 2, 64, 64
776
+
777
+ x = self.inc_1(x)
778
+ x = x.view(1, n_instances, -1) # flatten each instance map to [1, n_instances, 64*64]
779
+ x = self.mlp(x) # Apply the MLP to get the output
780
+
781
+ return x # output shape: [1, n_instances, 320]
782
+
783
+
784
+
785
+ def initialize_weights(self):
786
+ for m in self.modules():
787
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
788
+ nn.init.xavier_normal_(m.weight)
789
+ if m.bias is not None:
790
+ nn.init.constant_(m.bias, 0)
791
+
792
+
793
+
794
+ class regressor_with_SD_features_inst_seg_unet(nn.Module):
795
+ def __init__(self, n_channels=8, n_classes=3, bilinear=False):
796
+ super(regressor_with_SD_features_inst_seg_unet, self).__init__()
797
+ self.n_channels = n_channels
798
+ self.n_classes = n_classes
799
+ self.bilinear = bilinear
800
+ self.norm = nn.LayerNorm(normalized_shape=(64, 64))
801
+ self.inc_0 = (DoubleConv(n_channels, 3))
802
+ self.inc = (DoubleConv(3, 64))
803
+ self.down1 = (Down(64, 128))
804
+ self.down2 = (Down(128, 256))
805
+ self.down3 = (Down(256, 512))
806
+ factor = 2 if bilinear else 1
807
+ self.down4 = (Down(512, 1024 // factor))
808
+ self.up1 = (Up(1024, 512 // factor, bilinear))
809
+ self.up2 = (Up(512, 256 // factor, bilinear))
810
+ self.up3 = (Up(256, 128 // factor, bilinear))
811
+ self.up4 = (Up(128, 64, bilinear))
812
+ self.outc = (OutConv(64, n_classes))
813
+
814
+ def forward(self, img, attn_stack, feature_list):
815
+ attn_stack = self.norm(attn_stack)
816
+ unet_feature = feature_list[-1]
817
+ unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
818
+ attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
819
+ unet_feature_mean = unet_feature_mean * attn_stack_mean
820
+ x = torch.cat([unet_feature_mean, attn_stack], dim=1) # [1, 1 + attn channels, 64, 64]; 5 channels with the default n_channels=8, since the 3-channel image is concatenated below
821
+ if x.shape[-1] != 512:
822
+ x = F.interpolate(x, size=(512, 512), mode="bilinear")
823
+ x = torch.cat([img, x], dim=1) # [1, 8, 512, 512]
824
+ x = self.inc_0(x)
825
+ x1 = self.inc(x)
826
+ x2 = self.down1(x1)
827
+ x3 = self.down2(x2)
828
+ x4 = self.down3(x3)
829
+ x5 = self.down4(x4)
830
+ x = self.up1(x5, x4)
831
+ x = self.up2(x, x3)
832
+ x = self.up3(x, x2)
833
+ x = self.up4(x, x1)
834
+ out = self.outc(x)
835
+ return out
836
+
837
+ def initialize_weights(self):
838
+ for m in self.modules():
839
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
840
+ nn.init.xavier_normal_(m.weight)
841
+ if m.bias is not None:
842
+ nn.init.constant_(m.bias, 0)
843
+
844
+
845
+ class regressor_with_SD_features_self(nn.Module):
846
+ def __init__(self):
847
+ super(regressor_with_SD_features_self, self).__init__()
848
+ self.layer = nn.Sequential(
849
+ nn.Conv2d(4096, 1024, kernel_size=1, stride=1),
850
+ nn.LeakyReLU(),
851
+ nn.LayerNorm((64, 64)),
852
+ nn.Conv2d(1024, 256, kernel_size=1, stride=1),
853
+ nn.LeakyReLU(),
854
+ nn.LayerNorm((64, 64)),
855
+ )
856
+ self.layer2 = nn.Sequential(
857
+ nn.Conv2d(256, 128, kernel_size=3, padding=1),
858
+ nn.LeakyReLU(),
859
+ nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
860
+ )
861
+ self.layer3 = nn.Sequential(
862
+ nn.Conv2d(128, 64, kernel_size=3, padding=1),
863
+ nn.ReLU(),
864
+ nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
865
+ )
866
+ self.layer4 = nn.Sequential(
867
+ nn.Conv2d(64, 32, kernel_size=3, padding=1),
868
+ nn.LeakyReLU(),
869
+ nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
870
+ )
871
+ self.conv = nn.Sequential(
872
+ nn.Conv2d(32, 1, kernel_size=1),
873
+ nn.ReLU()
874
+ )
875
+ self.norm = nn.LayerNorm(normalized_shape=(64, 64))
876
+ self.initialize_weights()
877
+
878
+ def forward(self, self_attn):
879
+ self_attn = self_attn.permute(2, 0, 1)
880
+ self_attn = self.layer(self_attn)
881
+ return self_attn
882
+ # attn_stack = self.norm(attn_stack)
883
+ # unet_feature = feature_list[-1]
884
+ # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
885
+ # unet_feature = unet_feature * attn_stack_mean
886
+ # unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
887
+ # x = self.layer(unet_feature)
888
+ # x = self.layer2(x)
889
+ # x = self.layer3(x)
890
+ # x = self.layer4(x)
891
+ # out = self.conv(x)
892
+ # return out / 100
893
+
894
+ def initialize_weights(self):
895
+ for m in self.modules():
896
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
897
+ nn.init.xavier_normal_(m.weight)
898
+ if m.bias is not None:
899
+ nn.init.constant_(m.bias, 0)
900
+
901
+
902
+ class regressor_with_SD_features_latent(nn.Module):
903
+ def __init__(self):
904
+ super(regressor_with_SD_features_latent, self).__init__()
905
+ self.layer = nn.Sequential(
906
+ nn.Conv2d(4, 256, kernel_size=1, stride=1),
907
+ nn.LeakyReLU(),
908
+ nn.LayerNorm((64, 64))
909
+ )
910
+ self.layer2 = nn.Sequential(
911
+ nn.Conv2d(256, 128, kernel_size=3, padding=1),
912
+ nn.LeakyReLU(),
913
+ nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
914
+ )
915
+ self.layer3 = nn.Sequential(
916
+ nn.Conv2d(128, 64, kernel_size=3, padding=1),
917
+ nn.ReLU(),
918
+ nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
919
+ )
920
+ self.layer4 = nn.Sequential(
921
+ nn.Conv2d(64, 32, kernel_size=3, padding=1),
922
+ nn.LeakyReLU(),
923
+ nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
924
+ )
925
+ self.conv = nn.Sequential(
926
+ nn.Conv2d(32, 1, kernel_size=1),
927
+ nn.ReLU()
928
+ )
929
+ self.norm = nn.LayerNorm(normalized_shape=(64, 64))
930
+ self.initialize_weights()
931
+
932
+ def forward(self, self_attn):
933
+ # self_attn = self_attn.permute(2, 0, 1)
934
+ self_attn = self.layer(self_attn)
935
+ return self_attn
936
+ # attn_stack = self.norm(attn_stack)
937
+ # unet_feature = feature_list[-1]
938
+ # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
939
+ # unet_feature = unet_feature * attn_stack_mean
940
+ # unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
941
+ # x = self.layer(unet_feature)
942
+ # x = self.layer2(x)
943
+ # x = self.layer3(x)
944
+ # x = self.layer4(x)
945
+ # out = self.conv(x)
946
+ # return out / 100
947
+
948
+ def initialize_weights(self):
949
+ for m in self.modules():
950
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
951
+ nn.init.xavier_normal_(m.weight)
952
+ if m.bias is not None:
953
+ nn.init.constant_(m.bias, 0)
954
+
955
+
956
+
957
+
958
+
959
+ class regressor_with_deconv(nn.Module):
960
+ def __init__(self):
961
+ super(regressor_with_deconv, self).__init__()
962
+ self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
963
+ self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
964
+ self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
965
+ self.deconv1 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
966
+ self.deconv2 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
967
+ self.leaky_relu = nn.LeakyReLU()
968
+ self.relu = nn.ReLU()
969
+ self.initialize_weights()
970
+
971
+ def forward(self, x):
972
+ x_ = self.conv1(x)
973
+ x_ = self.leaky_relu(x_)
974
+ x_ = self.deconv1(x_)
975
+ x_ = self.conv2(x_)
976
+ x_ = self.leaky_relu(x_)
977
+ x_ = self.deconv2(x_)
978
+ x_ = self.conv3(x_)
979
+ x_ = self.relu(x_)
980
+ out = x_
981
+ return out
982
+
983
+ def initialize_weights(self):
984
+ for m in self.modules():
985
+ if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
986
+ nn.init.xavier_normal_(m.weight)
987
+ if m.bias is not None:
988
+ nn.init.constant_(m.bias, 0)
989
+
990
+
991
+
models/seg_post_model/cellpose/__init__.py ADDED
@@ -0,0 +1 @@
1
+ from .version import version, version_str
models/seg_post_model/cellpose/__main__.py ADDED
@@ -0,0 +1,272 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import os, time
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from cellpose import utils, models, io, train
8
+ from .version import version_str
9
+ from cellpose.cli import get_arg_parser
10
+
11
+ try:
12
+ from cellpose.gui import gui3d, gui
13
+ GUI_ENABLED = True
14
+ except ImportError as err:
15
+ GUI_ERROR = err
16
+ GUI_ENABLED = False
17
+ GUI_IMPORT = True
18
+ except Exception as err:
19
+ GUI_ENABLED = False
20
+ GUI_ERROR = err
21
+ GUI_IMPORT = False
22
+ raise
23
+
24
+ import logging
25
+
26
+
27
+ def main():
28
+ """ Run cellpose from command line
29
+ """
30
+
31
+ args = get_arg_parser().parse_args() # this has to be in a separate file for autodoc to work
32
+
33
+ if args.version:
34
+ print(version_str)
35
+ return
36
+
37
+ ######## if no image arguments are provided, run GUI or add model and exit ########
38
+ if len(args.dir) == 0 and len(args.image_path) == 0:
39
+ if args.add_model:
40
+ io.add_model(args.add_model)
41
+ return
42
+ else:
43
+ if not GUI_ENABLED:
44
+ print("GUI ERROR: %s" % GUI_ERROR)
45
+ if GUI_IMPORT:
46
+ print(
47
+ "GUI FAILED: GUI dependencies may not be installed, to install, run"
48
+ )
49
+ print(" pip install 'cellpose[gui]'")
50
+ else:
51
+ if args.Zstack:
52
+ gui3d.run()
53
+ else:
54
+ gui.run()
55
+ return
56
+
57
+ ############################## run cellpose on images ##############################
58
+ if args.verbose:
59
+ from .io import logger_setup
60
+ logger, log_file = logger_setup()
61
+ else:
62
+ print(
63
+ ">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
64
+ print("No --verbose => no progress or info printed")
65
+ logger = logging.getLogger(__name__)
66
+
67
+
68
+ # find images
69
+ if len(args.img_filter) > 0:
70
+ image_filter = args.img_filter
71
+ else:
72
+ image_filter = None
73
+
74
+ device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu,
75
+ device=args.gpu_device)
76
+
77
+ if args.pretrained_model is None or args.pretrained_model == "None" or args.pretrained_model == "False" or args.pretrained_model == "0":
78
+ pretrained_model = "cpsam"
79
+ logger.warning("training from scratch is disabled, using 'cpsam' model")
80
+ else:
81
+ pretrained_model = args.pretrained_model
82
+
83
+ # Warn users about old arguments from CP3:
84
+ if args.pretrained_model_ortho:
85
+ logger.warning(
86
+ "the '--pretrained_model_ortho' flag is deprecated in v4.0.1+ and no longer used")
87
+ if args.train_size:
88
+ logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")
89
+ if args.chan or args.chan2:
90
+ logger.warning('--chan and --chan2 are deprecated, all channels are used by default')
91
+ if args.all_channels:
92
+ logger.warning("the '--all_channels' flag is deprecated in v4.0.1+ and no longer used")
93
+ if args.restore_type:
94
+ logger.warning("the '--restore_type' flag is deprecated in v4.0.1+ and no longer used")
95
+ if args.transformer:
96
+ logger.warning("the '--tranformer' flag is deprecated in v4.0.1+ and no longer used")
97
+ if args.invert:
98
+ logger.warning("the '--invert' flag is deprecated in v4.0.1+ and no longer used")
99
+ if args.chan2_restore:
100
+ logger.warning("the '--chan2_restore' flag is deprecated in v4.0.1+ and no longer used")
101
+ if args.diam_mean:
102
+ logger.warning("the '--diam_mean' flag is deprecated in v4.0.1+ and no longer used")
103
+ if args.train_size:
104
+ logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")
105
+
106
+ if args.norm_percentile is not None:
107
+ value1, value2 = args.norm_percentile
108
+ normalize = {'percentile': (float(value1), float(value2))}
109
+ else:
110
+ normalize = (not args.no_norm)
111
+
112
+ if args.save_each:
113
+ if not args.save_every:
114
+ raise ValueError("ERROR: --save_each requires --save_every")
115
+
116
+ if len(args.image_path) > 0 and args.train:
117
+ raise ValueError("ERROR: cannot train model with single image input")
118
+
119
+ ## Run evaluation on images
120
+ if not args.train:
121
+ _evaluate_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
122
+
123
+ ## Train a model ##
124
+ else:
125
+ _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
126
+
127
+
128
+ def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize):
129
+ test_dir = None if len(args.test_dir) == 0 else args.test_dir
130
+ images, labels, image_names, train_probs = None, None, None, None
131
+ test_images, test_labels, image_names_test, test_probs = None, None, None, None
132
+ compute_flows = False
133
+ if len(args.file_list) > 0:
134
+ if os.path.exists(args.file_list):
135
+ dat = np.load(args.file_list, allow_pickle=True).item()
136
+ image_names = dat["train_files"]
137
+ image_names_test = dat.get("test_files", None)
138
+ train_probs = dat.get("train_probs", None)
139
+ test_probs = dat.get("test_probs", None)
140
+ compute_flows = dat.get("compute_flows", False)
141
+ load_files = False
142
+ else:
143
+ logger.critical(f"ERROR: {args.file_list} does not exist")
144
+ else:
145
+ output = io.load_train_test_data(args.dir, test_dir, image_filter,
146
+ args.mask_filter,
147
+ args.look_one_level_down)
148
+ images, labels, image_names, test_images, test_labels, image_names_test = output
149
+ load_files = True
150
+
151
+ # initialize model
152
+ model = models.CellposeModel(device=device, pretrained_model=pretrained_model)
153
+
154
+ # train segmentation model
155
+ cpmodel_path = train.train_seg(
156
+ model.net, images, labels, train_files=image_names,
157
+ test_data=test_images, test_labels=test_labels,
158
+ test_files=image_names_test, train_probs=train_probs,
159
+ test_probs=test_probs, compute_flows=compute_flows,
160
+ load_files=load_files, normalize=normalize,
161
+ channel_axis=args.channel_axis,
162
+ learning_rate=args.learning_rate, weight_decay=args.weight_decay,
163
+ SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.train_batch_size,
164
+ min_train_masks=args.min_train_masks,
165
+ nimg_per_epoch=args.nimg_per_epoch,
166
+ nimg_test_per_epoch=args.nimg_test_per_epoch,
167
+ save_path=os.path.realpath(args.dir),
168
+ save_every=args.save_every,
169
+ save_each=args.save_each,
170
+ model_name=args.model_name_out)[0]
171
+ model.pretrained_model = cpmodel_path
172
+ logger.info(">>>> model trained and saved to %s" % cpmodel_path)
173
+ return model
174
+
175
+
176
+ def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, normalize):
177
+ # record whether any outputs will be saved (used below when writing masks)
178
+ if not args.train:
179
+ saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt
180
+
181
+ tic = time.time()
182
+ if len(args.dir) > 0:
183
+ image_names = io.get_image_files(
184
+ args.dir, args.mask_filter, imf=imf,
185
+ look_one_level_down=args.look_one_level_down)
186
+ else:
187
+ if os.path.exists(args.image_path):
188
+ image_names = [args.image_path]
189
+ else:
190
+ raise ValueError(f"ERROR: no file found at {args.image_path}")
191
+ nimg = len(image_names)
192
+
193
+ if args.savedir:
194
+ if not os.path.exists(args.savedir):
195
+ raise FileExistsError(f"--savedir {args.savedir} does not exist")
196
+
197
+ logger.info(
198
+ ">>>> running cellpose on %d images using all channels" % nimg)
199
+
200
+ # handle built-in model exceptions
201
+ model = models.CellposeModel(device=device, pretrained_model=pretrained_model,)
202
+
203
+ tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)
204
+
205
+ channel_axis = args.channel_axis
206
+ z_axis = args.z_axis
207
+
208
+ for image_name in tqdm(image_names, file=tqdm_out):
209
+ if args.do_3D or args.stitch_threshold > 0.:
210
+ logger.info('loading image as 3D zstack')
211
+ image = io.imread_3D(image_name)
212
+ if channel_axis is None:
213
+ channel_axis = 3
214
+ if z_axis is None:
215
+ z_axis = 0
216
+
217
+ else:
218
+ image = io.imread_2D(image_name)
219
+ out = model.eval(
220
+ image,
221
+ diameter=args.diameter,
222
+ do_3D=args.do_3D,
223
+ augment=args.augment,
224
+ flow_threshold=args.flow_threshold,
225
+ cellprob_threshold=args.cellprob_threshold,
226
+ stitch_threshold=args.stitch_threshold,
227
+ min_size=args.min_size,
228
+ batch_size=args.batch_size,
229
+ bsize=args.bsize,
230
+ resample=not args.no_resample,
231
+ normalize=normalize,
232
+ channel_axis=channel_axis,
233
+ z_axis=z_axis,
234
+ anisotropy=args.anisotropy,
235
+ niter=args.niter,
236
+ flow3D_smooth=args.flow3D_smooth)
237
+ masks, flows = out[:2]
238
+
239
+ if args.exclude_on_edges:
240
+ masks = utils.remove_edge_masks(masks)
241
+ if not args.no_npy:
242
+ io.masks_flows_to_seg(image, masks, flows, image_name,
243
+ imgs_restore=None,
244
+ restore_type=None,
245
+ ratio=1.)
246
+ if saving_something:
247
+ suffix = "_cp_masks"
248
+ if args.output_name is not None:
249
+ # (1) If `savedir` is not defined, then must have a non-zero `suffix`
250
+ if args.savedir is None and len(args.output_name) > 0:
251
+ suffix = args.output_name
252
+ elif args.savedir is not None and not os.path.samefile(args.savedir, args.dir):
253
+ # (2) If `savedir` is defined, and different from `dir` then
254
+ # takes the value passed as a param. (which can be empty string)
255
+ suffix = args.output_name
256
+
257
+ io.save_masks(image, masks, flows, image_name,
258
+ suffix=suffix, png=args.save_png,
259
+ tif=args.save_tif, save_flows=args.save_flows,
260
+ save_outlines=args.save_outlines,
261
+ dir_above=args.dir_above, savedir=args.savedir,
262
+ save_txt=args.save_txt, in_folders=args.in_folders,
263
+ save_mpl=args.save_mpl)
264
+ if args.save_rois:
265
+ io.save_rois(masks, image_name)
266
+ logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))
267
+
268
+ return model
269
+
270
+
271
+ if __name__ == "__main__":
272
+ main()
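Putting the flags above together, a typical invocation (paths are placeholders, and this assumes the vendored package is importable as `cellpose`):

# segment a folder of images on GPU, saving PNG masks, with logging on
python -m cellpose --dir /path/to/images --use_gpu --save_png --verbose

# fine-tune from the bundled weights on a labeled folder
python -m cellpose --train --dir /path/to/train --test_dir /path/to/test --pretrained_model cpsam --verbose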
models/seg_post_model/cellpose/cli.py ADDED
@@ -0,0 +1,240 @@
1
+ """
2
+ Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer, Marius Pachitariu and Michael Rariden.
3
+ """
4
+
5
+ import argparse
6
+
7
+
8
+ def get_arg_parser():
9
+ """ Parses command line arguments for cellpose main function
10
+
11
+ Note: this function has to be in a separate file to allow autodoc to work for CLI.
12
+ The autodoc_mock_imports in conf.py does not work for sphinx-argparse sometimes,
13
+ see https://github.com/ashb/sphinx-argparse/issues/9#issue-1097057823
14
+ """
15
+
16
+ parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters")
17
+
18
+ # misc settings
19
+ parser.add_argument("--version", action="store_true",
20
+ help="show cellpose version info")
21
+ parser.add_argument(
22
+ "--verbose", action="store_true",
23
+ help="show information about running and settings and save to log")
24
+ parser.add_argument("--Zstack", action="store_true", help="run GUI in 3D mode")
25
+
26
+ # settings for CPU vs GPU
27
+ hardware_args = parser.add_argument_group("Hardware Arguments")
28
+ hardware_args.add_argument("--use_gpu", action="store_true",
29
+ help="use gpu if torch with cuda installed")
30
+ hardware_args.add_argument(
31
+ "--gpu_device", required=False, default="0", type=str,
32
+ help="which gpu device to use, use an integer for torch, or mps for M1")
33
+
34
+ # settings for locating and formatting images
35
+ input_img_args = parser.add_argument_group("Input Image Arguments")
36
+ input_img_args.add_argument("--dir", default=[], type=str,
37
+ help="folder containing data to run or train on.")
38
+ input_img_args.add_argument(
39
+ "--image_path", default=[], type=str, help=
40
+ "if given and --dir not given, run on single image instead of folder (cannot train with this option)"
41
+ )
42
+ input_img_args.add_argument(
43
+ "--look_one_level_down", action="store_true",
44
+ help="run processing on all subdirectories of current folder")
45
+ input_img_args.add_argument("--img_filter", default=[], type=str,
46
+ help="end string for images to run on")
47
+ input_img_args.add_argument(
48
+ "--channel_axis", default=None, type=int,
49
+ help="axis of image which corresponds to image channels")
50
+ input_img_args.add_argument("--z_axis", default=None, type=int,
51
+ help="axis of image which corresponds to Z dimension")
52
+
53
+ # TODO: remove deprecated in future version
54
+ input_img_args.add_argument(
55
+ "--chan", default=0, type=int, help=
56
+ "Deprecated in v4.0.1+, not used. ")
57
+ input_img_args.add_argument(
58
+ "--chan2", default=0, type=int, help=
59
+ 'Deprecated in v4.0.1+, not used. ')
60
+ input_img_args.add_argument("--invert", action="store_true", help=
61
+ 'Deprecated in v4.0.1+, not used. ')
62
+ input_img_args.add_argument(
63
+ "--all_channels", action="store_true", help=
64
+ 'Deprecated in v4.0.1+, not used. ')
65
+
66
+ # model settings
67
+ model_args = parser.add_argument_group("Model Arguments")
68
+ model_args.add_argument("--pretrained_model", required=False, default="cpsam",
69
+ type=str,
70
+ help="model to use for running or starting training")
71
+ model_args.add_argument(
72
+ "--add_model", required=False, default=None, type=str,
73
+ help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
74
+ model_args.add_argument("--pretrained_model_ortho", required=False, default=None,
75
+ type=str,
76
+ help="Deprecated in v4.0.1+, not used. ")
77
+
78
+ # TODO: remove deprecated in future version
79
+ model_args.add_argument("--restore_type", required=False, default=None, type=str, help=
80
+ 'Deprecated in v4.0.1+, not used. ')
81
+ model_args.add_argument("--chan2_restore", action="store_true", help=
82
+ 'Deprecated in v4.0.1+, not used. ')
83
+ model_args.add_argument(
84
+ "--transformer", action="store_true", help=
85
+ "use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)")
86
+
87
+ # algorithm settings
88
+ algorithm_args = parser.add_argument_group("Algorithm Arguments")
89
+ algorithm_args.add_argument("--no_norm", action="store_true",
90
+ help="do not normalize images (normalize=False)")
91
+ algorithm_args.add_argument(
92
+ '--norm_percentile',
93
+ nargs=2, # Require exactly two values
94
+ metavar=('VALUE1', 'VALUE2'),
95
+ help="Provide two float values to set norm_percentile (e.g., --norm_percentile 1 99)"
96
+ )
97
+ algorithm_args.add_argument(
98
+ "--do_3D", action="store_true",
99
+ help="process images as 3D stacks of images (nplanes x nchan x Ly x Lx")
100
+ algorithm_args.add_argument(
101
+ "--diameter", required=False, default=None, type=float, help=
102
+ "use to resize cells to the training diameter (30 pixels)"
103
+ )
104
+ algorithm_args.add_argument(
105
+ "--stitch_threshold", required=False, default=0.0, type=float,
106
+ help="compute masks in 2D then stitch together masks with IoU>0.9 across planes"
107
+ )
108
+ algorithm_args.add_argument(
109
+ "--min_size", required=False, default=15, type=int,
110
+ help="minimum number of pixels per mask, can turn off with -1")
111
+ algorithm_args.add_argument(
112
+ "--flow3D_smooth", required=False, default=0, type=float,
113
+ help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing")
114
+ algorithm_args.add_argument(
115
+ "--flow_threshold", default=0.4, type=float, help=
116
+ "flow error threshold, 0 turns off this optional QC step. Default: %(default)s")
117
+ algorithm_args.add_argument(
118
+ "--cellprob_threshold", default=0, type=float,
119
+ help="cellprob threshold, default is 0, decrease to find more and larger masks")
120
+ algorithm_args.add_argument(
121
+ "--niter", default=0, type=int, help=
122
+ "niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs"
123
+ )
124
+ algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
125
+ help="anisotropy of volume in 3D")
126
+ algorithm_args.add_argument("--exclude_on_edges", action="store_true",
127
+ help="discard masks which touch edges of image")
128
+ algorithm_args.add_argument(
129
+ "--augment", action="store_true",
130
+ help="tiles image with overlapping tiles and flips overlapped regions to augment"
131
+ )
132
+ algorithm_args.add_argument("--batch_size", default=8, type=int,
133
+ help="inference batch size. Default: %(default)s")
134
+
135
+ # TODO: remove deprecated in future version
136
+ algorithm_args.add_argument(
137
+ "--no_resample", action="store_true",
138
+ help="disables flows/cellprob resampling to original image size before computing masks. Using this flag will make more masks more jagged with larger diameter settings.")
139
+ algorithm_args.add_argument(
140
+ "--no_interp", action="store_true",
141
+ help="do not interpolate when running dynamics (was default)")
142
+
143
+ # output settings
144
+ output_args = parser.add_argument_group("Output Arguments")
145
+ output_args.add_argument(
146
+ "--save_png", action="store_true",
147
+ help="save masks as png")
148
+ output_args.add_argument(
149
+ "--save_tif", action="store_true",
150
+ help="save masks as tif")
151
+ output_args.add_argument(
152
+ "--output_name", default=None, type=str,
153
+ help="suffix for saved masks, default is _cp_masks, can be empty if `savedir` used and different of `dir`")
154
+ output_args.add_argument("--no_npy", action="store_true",
155
+ help="suppress saving of npy")
156
+ output_args.add_argument(
157
+ "--savedir", default=None, type=str, help=
158
+ "folder to which segmentation results will be saved (defaults to input image directory)"
159
+ )
160
+ output_args.add_argument(
161
+ "--dir_above", action="store_true", help=
162
+ "save output folders adjacent to image folder instead of inside it (off by default)"
163
+ )
164
+ output_args.add_argument("--in_folders", action="store_true",
165
+ help="flag to save output in folders (off by default)")
166
+ output_args.add_argument(
167
+ "--save_flows", action="store_true", help=
168
+ "whether or not to save RGB images of flows when masks are saved (disabled by default)"
169
+ )
170
+ output_args.add_argument(
171
+ "--save_outlines", action="store_true", help=
172
+ "whether or not to save RGB outline images when masks are saved (disabled by default)"
173
+ )
174
+ output_args.add_argument(
175
+ "--save_rois", action="store_true",
176
+ help="whether or not to save ImageJ compatible ROI archive (disabled by default)"
177
+ )
178
+ output_args.add_argument(
179
+ "--save_txt", action="store_true",
180
+ help="flag to enable txt outlines for ImageJ (disabled by default)")
181
+ output_args.add_argument(
182
+ "--save_mpl", action="store_true",
183
+ help="save a figure of image/mask/flows using matplotlib (disabled by default). "
184
+ "This is slow, especially with large images.")
185
+
186
+ # training settings
187
+ training_args = parser.add_argument_group("Training Arguments")
188
+ training_args.add_argument("--train", action="store_true",
189
+ help="train network using images in dir")
190
+ training_args.add_argument("--test_dir", default=[], type=str,
191
+ help="folder containing test data (optional)")
192
+ training_args.add_argument(
193
+ "--file_list", default=[], type=str, help=
194
+ "path to list of files for training and testing and probabilities for each image (optional)"
195
+ )
196
+ training_args.add_argument(
197
+ "--mask_filter", default="_masks", type=str, help=
198
+ "end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s"
199
+ )
200
+ training_args.add_argument("--learning_rate", default=1e-5, type=float,
201
+ help="learning rate. Default: %(default)s")
202
+ training_args.add_argument("--weight_decay", default=0.1, type=float,
203
+ help="weight decay. Default: %(default)s")
204
+ training_args.add_argument("--n_epochs", default=100, type=int,
205
+ help="number of epochs. Default: %(default)s")
206
+ training_args.add_argument("--train_batch_size", default=1, type=int,
207
+ help="training batch size. Default: %(default)s")
208
+ training_args.add_argument("--bsize", default=256, type=int,
209
+ help="block size for tiles. Default: %(default)s")
210
+ training_args.add_argument(
211
+ "--nimg_per_epoch", default=None, type=int,
212
+ help="number of train images per epoch. Default is to use all train images.")
213
+ training_args.add_argument(
214
+ "--nimg_test_per_epoch", default=None, type=int,
215
+ help="number of test images per epoch. Default is to use all test images.")
216
+ training_args.add_argument(
217
+ "--min_train_masks", default=5, type=int, help=
218
+ "minimum number of masks a training image must have to be used. Default: %(default)s"
219
+ )
220
+ training_args.add_argument("--SGD", default=0, type=int,
221
+ help="Deprecated in v4.0.1+, not used - AdamW used instead. ")
222
+ training_args.add_argument(
223
+ "--save_every", default=100, type=int,
224
+ help="number of epochs to skip between saves. Default: %(default)s")
225
+ training_args.add_argument(
226
+ "--save_each", action="store_true",
227
+ help="wether or not to save each epoch. Must also use --save_every. (default: False)")
228
+ training_args.add_argument(
229
+ "--model_name_out", default=None, type=str,
230
+ help="Name of model to save as, defaults to name describing model architecture. "
231
+ "Model is saved in the folder specified by --dir in models subfolder.")
232
+
233
+ # TODO: remove deprecated in future version
234
+ training_args.add_argument(
235
+ "--diam_mean", default=30., type=float, help=
236
+ 'Deprecated in v4.0.1+, not used. ')
237
+ training_args.add_argument("--train_size", action="store_true", help=
238
+ 'Deprecated in v4.0.1+, not used. ')
239
+
240
+ return parser
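
A minimal usage sketch for the parser assembled above (not part of the uploaded file). The factory name `get_arg_parser` is an assumption, since the `def` line sits above this hunk; only flags defined above are exercised.

def main(argv=None):
    parser = get_arg_parser()   # hypothetical name for the function that returns `parser`
    args = parser.parse_args(argv)
    normalize = not args.no_norm          # --no_norm maps onto normalize=False downstream
    print(args.diameter, args.flow_threshold, args.save_png, normalize)

main(["--diameter", "30", "--save_png"])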
models/seg_post_model/cellpose/core.py ADDED
@@ -0,0 +1,322 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import logging
5
+ import numpy as np
6
+ from tqdm import trange
7
+ from . import transforms, utils
8
+
9
+ import torch
10
+
11
+ TORCH_ENABLED = True
12
+
13
+ core_logger = logging.getLogger(__name__)
14
+ tqdm_out = utils.TqdmToLogger(core_logger, level=logging.INFO)
15
+
16
+
17
+ def use_gpu(gpu_number=0, use_torch=True):
18
+ """
19
+ Check if GPU is available for use.
20
+
21
+ Args:
22
+ gpu_number (int): The index of the GPU to be used. Default is 0.
23
+ use_torch (bool): Whether to use PyTorch for GPU check. Default is True.
24
+
25
+ Returns:
26
+ bool: True if GPU is available, False otherwise.
27
+
28
+ Raises:
29
+ ValueError: If use_torch is False, as cellpose only runs with PyTorch now.
30
+ """
31
+ if use_torch:
32
+ return _use_gpu_torch(gpu_number)
33
+ else:
34
+ raise ValueError("cellpose only runs with PyTorch now")
35
+
36
+
37
+ def _use_gpu_torch(gpu_number=0):
38
+ """
39
+ Checks if CUDA or MPS is available and working with PyTorch.
40
+
41
+ Args:
42
+ gpu_number (int): The GPU device number to use (default is 0).
43
+
44
+ Returns:
45
+ bool: True if CUDA or MPS is available and working, False otherwise.
46
+ """
47
+ try:
48
+ device = torch.device("cuda:" + str(gpu_number))
49
+ _ = torch.zeros((1,1)).to(device)
50
+ core_logger.info("** TORCH CUDA version installed and working. **")
51
+ return True
52
+ except:
53
+ pass
54
+ try:
55
+ device = torch.device('mps:' + str(gpu_number))
56
+ _ = torch.zeros((1,1)).to(device)
57
+ core_logger.info('** TORCH MPS version installed and working. **')
58
+ return True
59
+ except:
60
+ core_logger.info('Neither TORCH CUDA nor MPS version installed/working.')
61
+ return False
62
+
63
+
64
+ def assign_device(use_torch=True, gpu=False, device=0):
65
+ """
66
+ Assigns the device (CPU or GPU or mps) to be used for computation.
67
+
68
+ Args:
69
+ use_torch (bool, optional): Whether to use torch for GPU detection. Defaults to True.
70
+ gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
71
+ device (int or str, optional): The device index or name to be used. Defaults to 0.
72
+
73
+ Returns:
74
+ torch.device, bool (True if GPU is used, False otherwise)
75
+ """
76
+
+ cpu = True # defensive: make sure `cpu` is defined on every path below
77
+ if isinstance(device, str):
78
+ if device != "mps" or not(gpu and torch.backends.mps.is_available()):
79
+ device = int(device)
80
+ if gpu and use_gpu(use_torch=True):
81
+ try:
82
+ if torch.cuda.is_available():
83
+ device = torch.device(f'cuda:{device}')
84
+ core_logger.info(">>>> using GPU (CUDA)")
85
+ gpu = True
86
+ cpu = False
87
+ except:
88
+ gpu = False
89
+ cpu = True
90
+ try:
91
+ if torch.backends.mps.is_available():
92
+ device = torch.device('mps')
93
+ core_logger.info(">>>> using GPU (MPS)")
94
+ gpu = True
95
+ cpu = False
96
+ except:
97
+ gpu = False
98
+ cpu = True
99
+ else:
100
+ device = torch.device('cpu')
101
+ core_logger.info('>>>> using CPU')
102
+ gpu = False
103
+ cpu = True
104
+
105
+ if cpu:
106
+ device = torch.device("cpu")
107
+ core_logger.info(">>>> using CPU")
108
+ gpu = False
109
+ return device, gpu
110
+
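
A short usage sketch (assumptions: torch is installed and this module is importable as cellpose.core; not part of the uploaded file):

import torch
from cellpose.core import assign_device

device, gpu = assign_device(use_torch=True, gpu=True, device=0)
x = torch.zeros((1, 3, 256, 256), device=device)   # tensors follow the chosen device
print(f"running on {device} (gpu={gpu})")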
111
+
112
+ def _to_device(x, device, dtype=torch.float32):
113
+ """
114
+ Converts the input tensor or numpy array to the specified device.
115
+
116
+ Args:
117
+ x (torch.Tensor or numpy.ndarray): The input tensor or numpy array.
118
+ device (torch.device): The target device.
119
+
120
+ Returns:
121
+ torch.Tensor: The converted tensor on the specified device.
122
+ """
123
+ if not isinstance(x, torch.Tensor):
124
+ X = torch.from_numpy(x).to(device, dtype=dtype)
125
+ return X
126
+ else:
127
+ return x
128
+
129
+
130
+ def _from_device(X):
131
+ """
132
+ Converts a PyTorch tensor from the device to a NumPy array on the CPU.
133
+
134
+ Args:
135
+ X (torch.Tensor): The input PyTorch tensor.
136
+
137
+ Returns:
138
+ numpy.ndarray: The converted NumPy array.
139
+ """
140
+ # The cast is so numpy conversion always works
141
+ x = X.detach().cpu().to(torch.float32).numpy()
142
+ return x
143
+
144
+
145
+ def _forward(net, x, feat=None):
146
+ """Converts images to torch tensors, runs the network model, and returns numpy arrays.
147
+
148
+ Args:
149
+ net (torch.nn.Module): The network model.
150
+ x (numpy.ndarray): The input images.
151
+
152
+ Returns:
153
+ Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features.
154
+ """
155
+ X = _to_device(x, device=net.device, dtype=net.dtype)
156
+ if feat is not None:
157
+ feat = _to_device(feat, device=net.device, dtype=net.dtype)
158
+ net.eval()
159
+ with torch.no_grad():
160
+ y, style = net(X, feat=feat)[:2]
161
+ del X
162
+ y = _from_device(y)
163
+ style = _from_device(style)
164
+ return y, style
165
+
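
_forward is the numpy boundary of the pipeline: arrays go in, arrays come out. A shape sketch, assuming `net` is a loaded Cellpose network exposing .device and .dtype (as the networks in this repo do):

import numpy as np

tiles = np.zeros((8, 3, 256, 256), "float32")   # ntiles x nchan x bsize x bsize
y, style = _forward(net, tiles)                 # y: flows + cellprob per tile; style: (8, 256)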
166
+
167
+ def run_net(net, imgi, feat=None, batch_size=8, augment=False, tile_overlap=0.1, bsize=224,
168
+ rsz=None):
169
+ """
170
+ Run network on stack of images.
171
+
172
+ (faster if augment is False)
173
+
174
+ Args:
175
+ net (class): cellpose network (model.net)
176
+ imgi (np.ndarray): The input image or stack of images of size [Lz x Ly x Lx x nchan].
177
+ batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
178
+ rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
179
+ augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
180
+ tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
181
+ bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
182
+
183
+ Returns:
184
+ Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
185
+ y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
186
+ style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
187
+ """
188
+ # run network
189
+ Lz, Ly0, Lx0, nchan = imgi.shape
190
+ if rsz is not None:
191
+ if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
192
+ rsz = [rsz, rsz]
193
+ Lyr, Lxr = int(Ly0 * rsz[0]), int(Lx0 * rsz[1])
194
+ else:
195
+ Lyr, Lxr = Ly0, Lx0 # 512, 512
196
+
197
+ ly, lx = bsize, bsize # 256, 256
198
+ ypad1, ypad2, xpad1, xpad2 = transforms.get_pad_yx(Lyr, Lxr, min_size=(bsize, bsize)) # 8
199
+ Ly, Lx = Lyr + ypad1 + ypad2, Lxr + xpad1 + xpad2 # 528, 528
200
+ pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])
201
+
202
+ if augment:
203
+ ny = max(2, int(np.ceil(2. * Ly / bsize)))
204
+ nx = max(2, int(np.ceil(2. * Lx / bsize)))
205
+ else:
206
+ ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize)) # 3
207
+ nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize)) # 3
208
+
209
+
210
+ # run multiple slices at the same time
211
+ ntiles = ny * nx
212
+ nimgs = max(1, batch_size // ntiles) # number of imgs to run in the same batch, 1
213
+ niter = int(np.ceil(Lz / nimgs)) # 1
214
+ ziterator = (trange(niter, file=tqdm_out, mininterval=30)
215
+ if niter > 10 or Lz > 1 else range(niter))
216
+ for k in ziterator:
217
+ inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs))
218
+ IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32") # 9, 3, 256, 256
219
+ if feat is not None:
220
+ FEATa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32") # 9, 256
221
+ else:
222
+ FEATa = None
223
+ for i, b in enumerate(inds):
224
+ # pad image for net so Ly and Lx are divisible by 4
225
+ imgb = transforms.resize_image(imgi[b], rsz=rsz) if rsz is not None else imgi[b].copy()
226
+ imgb = np.pad(imgb.transpose(2,0,1), pads, mode="constant") # 3, 528, 528
227
+
228
+ IMG, ysub, xsub, Lyt, Lxt = transforms.make_tiles(
229
+ imgb, bsize=bsize, augment=augment,
230
+ tile_overlap=tile_overlap) # IMG: 3, 3, 3, 256, 256
231
+ IMGa[i * ntiles : (i+1) * ntiles] = np.reshape(IMG,
232
+ (ny * nx, nchan, ly, lx))
233
+ if feat is not None:
234
+ featb = transforms.resize_image(feat[b], rsz=rsz) if rsz is not None else feat[b].copy()
235
+ featb = np.pad(featb.transpose(2,0,1), pads, mode="constant")
236
+ FEAT, ysub, xsub, Lyt, Lxt = transforms.make_tiles(
237
+ featb, bsize=bsize, augment=augment,
238
+ tile_overlap=tile_overlap)
239
+ FEATa[i * ntiles : (i+1) * ntiles] = np.reshape(FEAT,
240
+ (ny * nx, nchan, ly, lx))
241
+
242
+ # run network
243
+ for j in range(0, IMGa.shape[0], batch_size):
244
+ bslc = slice(j, min(j + batch_size, IMGa.shape[0]))
245
+ ya0, stylea0 = _forward(net, IMGa[bslc], feat=FEATa[bslc] if FEATa is not None else None)
246
+ if j == 0:
247
+ nout = ya0.shape[1]
248
+ ya = np.zeros((IMGa.shape[0], nout, ly, lx), "float32")
249
+ stylea = np.zeros((IMGa.shape[0], 256), "float32")
250
+ ya[bslc] = ya0
251
+ stylea[bslc] = stylea0
252
+
253
+ # average tiles
254
+ for i, b in enumerate(inds):
255
+ if i==0 and k==0:
256
+ yf = np.zeros((Lz, nout, Ly, Lx), "float32")
257
+ styles = np.zeros((Lz, 256), "float32")
258
+ y = ya[i * ntiles : (i + 1) * ntiles]
259
+ if augment:
260
+ y = np.reshape(y, (ny, nx, 3, ly, lx))
261
+ y = transforms.unaugment_tiles(y)
262
+ y = np.reshape(y, (-1, 3, ly, lx))
263
+ yfi = transforms.average_tiles(y, ysub, xsub, Lyt, Lxt)
264
+ yf[b] = yfi[:, :imgb.shape[-2], :imgb.shape[-1]]
265
+ stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0)
266
+ stylei /= (stylei**2).sum()**0.5
267
+ styles[b] = stylei
268
+ # slices from padding
269
+ yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2]
270
+ yf = yf.transpose(0,2,3,1)
271
+ return yf, np.array(styles)
272
+
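
A worked example of the tile counts computed above (pure arithmetic, matching the inline "# 3" / "# 9" notes, assuming a 512x512 image padded to 528 with bsize=256):

import numpy as np

Ly, bsize, tile_overlap = 528, 256, 0.1
ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
print(ny, ny * ny)   # 3 tiles per side, 9 tiles total for a square image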
273
+
274
+ def run_3D(net, imgs, batch_size=8, augment=False,
275
+ tile_overlap=0.1, bsize=224, net_ortho=None,
276
+ progress=None):
277
+ """
278
+ Run network on image z-stack.
279
+
280
+ (faster if augment is False)
281
+
282
+ Args:
283
+ imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan].
284
+ batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
285
+ rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
286
+ anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
287
+ augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
288
+ tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
289
+ bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
290
+ net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes. Defaults to None.
291
+ progress (QProgressBar, optional): pyqt progress bar. Defaults to None.
292
+
293
+ Returns:
294
+ Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
295
+ y[...,0] is Z flow; y[...,1] is Y flow; y[...,2] is X flow; y[...,3] is cell probability.
296
+ style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
297
+ """
298
+ sstr = ["YX", "ZY", "ZX"]
299
+ pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]
300
+ ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]
301
+ cp = [(1, 2), (0, 2), (0, 1)]
302
+ cpy = [(0, 1), (0, 1), (0, 1)]
303
+ shape = imgs.shape[:-1]
304
+ yf = np.zeros((*shape, 4), "float32")
305
+ for p in range(3):
306
+ xsl = imgs.transpose(pm[p])
307
+ # per image
308
+ core_logger.info("running %s: %d planes of size (%d, %d)" %
309
+ (sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]]))
310
+ y, style = run_net(net,
311
+ xsl, batch_size=batch_size, augment=augment,
312
+ bsize=bsize, tile_overlap=tile_overlap,
313
+ rsz=None)
314
+ yf[..., -1] += y[..., -1].transpose(ipm[p])
315
+ for j in range(2):
316
+ yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])
317
+ y = None; del y
318
+
319
+ if progress is not None:
320
+ progress.setValue(25 + 15 * p)
321
+
322
+ return yf, style
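
A standalone sanity check of the permutation tables used above (toy array, no model needed): transposing a ZYXC volume with pm[p], dropping the channel axis, then transposing back with ipm[p] must restore the original ZYX orientation.

import numpy as np

pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]
ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]
vol = np.random.rand(4, 5, 6, 1)                 # Lz x Ly x Lx x nchan
for p in range(3):
    out = vol.transpose(pm[p])[..., 0]           # the planes run_net sees for view p
    assert out.transpose(ipm[p]).shape == vol.shape[:3]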
models/seg_post_model/cellpose/denoise.py ADDED
@@ -0,0 +1,1474 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import os, time, datetime
5
+ import numpy as np
6
+ from scipy.stats import mode
7
+ import cv2
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn.functional import conv2d, interpolate
11
+ from tqdm import trange
12
+ from pathlib import Path
13
+
14
+ import logging
15
+
16
+ denoise_logger = logging.getLogger(__name__)
17
+
18
+ from cellpose import transforms, utils, io
+ from cellpose import resnet_torch # assumption: available as in upstream Cellpose 3; provides CPnet used below
+ from cellpose.resnet_torch import CPnet
19
+ from cellpose.core import run_net
20
+ from cellpose.models import CellposeModel, model_path, normalize_default, assign_device
21
+
22
+ MODEL_NAMES = []
23
+ for ctype in ["cyto3", "cyto2", "nuclei"]:
24
+ for ntype in ["denoise", "deblur", "upsample", "oneclick"]:
25
+ MODEL_NAMES.append(f"{ntype}_{ctype}")
26
+ if ctype != "cyto3":
27
+ for ltype in ["per", "seg", "rec"]:
28
+ MODEL_NAMES.append(f"{ntype}_{ltype}_{ctype}")
29
+ if ctype != "cyto3":
30
+ MODEL_NAMES.append(f"aniso_{ctype}")
31
+
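
A quick check of what the loop above generates (no new logic): the per/seg/rec and aniso variants exist only for cyto2 and nuclei, not for cyto3.

assert "denoise_cyto3" in MODEL_NAMES and "denoise_per_cyto2" in MODEL_NAMES
assert "aniso_nuclei" in MODEL_NAMES and "aniso_cyto3" not in MODEL_NAMES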
32
+ criterion = nn.MSELoss(reduction="mean")
33
+ criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
34
+
35
+
36
+ def deterministic(seed=0):
37
+ """ set random seeds to create test data """
38
+ import random
39
+ torch.manual_seed(seed)
40
+ torch.cuda.manual_seed(seed)
41
+ torch.cuda.manual_seed_all(seed) # if you are using multi-GPU.
42
+ np.random.seed(seed) # Numpy module.
43
+ random.seed(seed) # Python random module.
44
+ torch.manual_seed(seed)
45
+ torch.backends.cudnn.benchmark = False
46
+ torch.backends.cudnn.deterministic = True
47
+
48
+
49
+ def loss_fn_rec(lbl, y):
50
+ """ loss function between true labels lbl and prediction y """
51
+ loss = 80. * criterion(y, lbl)
52
+ return loss
53
+
54
+
55
+ def loss_fn_seg(lbl, y):
56
+ """ loss function between true labels lbl and prediction y """
57
+ veci = 5. * lbl[:, 1:]
58
+ lbl = (lbl[:, 0] > .5).float()
59
+ loss = criterion(y[:, :2], veci)
60
+ loss /= 2.
61
+ loss2 = criterion2(y[:, 2], lbl)
62
+ loss = loss + loss2
63
+ return loss
64
+
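
A shape sketch for loss_fn_seg (toy tensors, not part of the module): channel 0 of lbl is the cell-probability target, channels 1: are the true flows, and y carries [dY, dX, cellprob logits].

import torch

lbl = torch.zeros(2, 3, 224, 224)
y = torch.randn(2, 3, 224, 224)
print(loss_fn_seg(lbl, y).item())   # MSE(y[:, :2], 5*flows)/2 + BCE(y[:, 2], prob > .5)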
65
+
66
+ def get_sigma(Tdown):
67
+ """ Calculates the correlation matrices across channels for the perceptual loss.
68
+
69
+ Args:
70
+ Tdown (list): List of tensors output by each downsampling block of network.
71
+
72
+ Returns:
73
+ list: List of correlations for each input tensor.
74
+ """
75
+ Tnorm = [x - x.mean((-2, -1), keepdim=True) for x in Tdown]
76
+ Tnorm = [x / x.std((-2, -1), keepdim=True) for x in Tnorm]
77
+ Sigma = [
78
+ torch.einsum("bnxy, bmxy -> bnm", x, x) / (x.shape[-2] * x.shape[-1])
79
+ for x in Tnorm
80
+ ]
81
+ return Sigma
82
+
83
+
84
+ def imstats(X, net1):
85
+ """
86
+ Calculates the image correlation matrices for the perceptual loss.
87
+
88
+ Args:
89
+ X (torch.Tensor): Input image tensor.
90
+ net1: Cellpose net.
91
+
92
+ Returns:
93
+ list: A list of tensors of correlation matrices.
94
+ """
95
+ _, _, Tdown = net1(X)
96
+ Sigma = get_sigma(Tdown)
97
+ Sigma = [x.detach() for x in Sigma]
98
+ return Sigma
99
+
100
+
101
+ def loss_fn_per(img, net1, yl):
102
+ """
103
+ Calculates the perceptual loss function for image restoration.
104
+
105
+ Args:
106
+ img (torch.Tensor): Input image tensor (noisy/blurry/downsampled).
107
+ net1 (torch.nn.Module): Perceptual loss net (Cellpose segmentation net).
108
+ yl (torch.Tensor): Clean image tensor.
109
+
110
+ Returns:
111
+ torch.Tensor: Mean perceptual loss.
112
+ """
113
+ Sigma = imstats(img, net1)
114
+ sd = [x.std((1, 2)) + 1e-6 for x in Sigma]
115
+ Sigma_test = get_sigma(yl)
116
+ losses = torch.zeros(len(Sigma[0]), device=img.device)
117
+ for k in range(len(Sigma)):
118
+ losses = losses + (((Sigma_test[k] - Sigma[k])**2).mean((1, 2)) / sd[k]**2)
119
+ return losses.mean()
120
+
121
+
122
+ def test_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
123
+ """
124
+ Calculates the test loss for image restoration tasks.
125
+
126
+ Args:
127
+ net0 (torch.nn.Module): The image restoration network.
128
+ X (torch.Tensor): The input image tensor.
129
+ net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
130
+ img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
131
+ lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
132
+ lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].
133
+
134
+ Returns:
135
+ tuple: A tuple containing the total loss and the perceptual loss.
136
+ """
137
+ net0.eval()
138
+ if net1 is not None:
139
+ net1.eval()
140
+ loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)
141
+
142
+ with torch.no_grad():
143
+ img_dn = net0(X)[0]
144
+ if lam[2] > 0.:
145
+ loss += lam[2] * loss_fn_rec(img, img_dn)
146
+ if lam[1] > 0. or lam[0] > 0.:
147
+ y, _, ydown = net1(img_dn)
148
+ if lam[1] > 0.:
149
+ loss += lam[1] * loss_fn_seg(lbl, y)
150
+ if lam[0] > 0.:
151
+ loss_per = loss_fn_per(img, net1, ydown)
152
+ loss += lam[0] * loss_per
153
+ return loss, loss_per
154
+
155
+
156
+ def train_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
157
+ """
158
+ Calculates the train loss for image restoration tasks.
159
+
160
+ Args:
161
+ net0 (torch.nn.Module): The image restoration network.
162
+ X (torch.Tensor): The input image tensor.
163
+ net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
164
+ img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
165
+ lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
166
+ lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].
167
+
168
+ Returns:
169
+ tuple: A tuple containing the total loss and the perceptual loss.
170
+ """
171
+ net0.train()
172
+ if net1 is not None:
173
+ net1.eval()
174
+ loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)
175
+
176
+ img_dn = net0(X)[0]
177
+ if lam[2] > 0.:
178
+ loss += lam[2] * loss_fn_rec(img, img_dn)
179
+ if lam[1] > 0. or lam[0] > 0.:
180
+ y, _, ydown = net1(img_dn)
181
+ if lam[1] > 0.:
182
+ loss += lam[1] * loss_fn_seg(lbl, y)
183
+ if lam[0] > 0.:
184
+ loss_per = loss_fn_per(img, net1, ydown)
185
+ loss += lam[0] * loss_per
186
+ return loss, loss_per
187
+
188
+
189
+ def img_norm(imgi):
190
+ """
191
+ Normalizes the input image by subtracting the 1st percentile and dividing by the difference between the 99th and 1st percentiles.
192
+
193
+ Args:
194
+ imgi (torch.Tensor): Input image tensor.
195
+
196
+ Returns:
197
+ torch.Tensor: Normalized image tensor.
198
+ """
199
+ shape = imgi.shape
200
+ imgi = imgi.reshape(imgi.shape[0], imgi.shape[1], -1)
201
+ perc = torch.quantile(imgi, torch.tensor([0.01, 0.99], device=imgi.device), dim=-1,
202
+ keepdim=True)
203
+ for k in range(imgi.shape[1]):
204
+ hask = (perc[1, :, k, 0] - perc[0, :, k, 0]) > 1e-3
205
+ imgi[hask, k] -= perc[0, hask, k]
206
+ imgi[hask, k] /= (perc[1, hask, k] - perc[0, hask, k])
207
+ imgi = imgi.reshape(shape)
208
+ return imgi
209
+
210
+
211
+ def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7,
212
+ ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None,
213
+ ds=None, uniform_blur=False, partial_blur=False):
214
+ """Adds noise to the input image.
215
+
216
+ Args:
217
+ lbl (torch.Tensor): The input image tensor of shape (nimg, nchan, Ly, Lx).
218
+ alpha (float, optional): The shape parameter of the gamma distribution used for generating poisson noise. Defaults to 4.
219
+ beta (float, optional): The rate parameter of the gamma distribution used for generating poisson noise. Defaults to 0.7.
220
+ poisson (float, optional): The probability of adding poisson noise to the image. Defaults to 0.7.
221
+ blur (float, optional): The probability of adding gaussian blur to the image. Defaults to 0.7.
222
+ gblur (float, optional): The scale factor for the gaussian blur. Defaults to 1.0.
223
+ downsample (float, optional): The probability of downsampling the image. Defaults to 0.7.
224
+ ds_max (int, optional): The maximum downsampling factor. Defaults to 7.
225
+ diams (torch.Tensor, optional): The diameter of the objects in the image. Defaults to None.
226
+ pscale (torch.Tensor, optional): The scale factor for the poisson noise, instead of sampling. Defaults to None.
227
+ iso (bool, optional): Whether to use isotropic gaussian blur. Defaults to True.
228
+ sigma0 (torch.Tensor, optional): The standard deviation of the gaussian filter for the Y axis, instead of sampling. Defaults to None.
229
+ sigma1 (torch.Tensor, optional): The standard deviation of the gaussian filter for the X axis, instead of sampling. Defaults to None.
230
+ ds (torch.Tensor, optional): The downsampling factor for each image, instead of sampling. Defaults to None.
231
+
232
+ Returns:
233
+ torch.Tensor: The noisy image tensor of the same shape as the input image.
234
+ """
235
+ device = lbl.device
236
+ imgi = torch.zeros_like(lbl)
237
+ Ly, Lx = lbl.shape[-2:]
238
+
239
+ diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device)
240
+ #ds0 = 1 if ds is None else ds.item()
241
+ ds = ds * torch.ones(
242
+ (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds
243
+
244
+ # downsample
245
+ ii = []
246
+ idownsample = np.random.rand(len(lbl)) < downsample
247
+ if (ds is None and idownsample.sum() > 0.) or not iso:
248
+ ds = torch.ones(len(lbl), dtype=torch.long, device=device)
249
+ ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),),
250
+ device=device)
251
+ ii = torch.nonzero(ds > 1).flatten()
252
+ elif ds is not None and (ds > 1).sum():
253
+ ii = torch.nonzero(ds > 1).flatten()
254
+
255
+ # add gaussian blur
256
+ iblur = torch.rand(len(lbl), device=device) < blur
257
+ iblur[ii] = True
258
+ if iblur.sum() > 0:
259
+ if sigma0 is None:
260
+ if uniform_blur and iso:
261
+ xr = torch.rand(len(lbl), device=device)
262
+ if len(ii) > 0:
263
+ xr[ii] = ds[ii].float() / 2. / gblur
264
+ sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur])
265
+ sigma1 = sigma0.clone()
266
+ elif not iso:
267
+ xr = torch.rand(len(lbl), device=device)
268
+ if len(ii) > 0:
269
+ xr[ii] = (ds[ii].float()) / gblur
270
+ xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35
271
+ xr[ii] = torch.clip(xr[ii], 0.05, 1.5)
272
+ sigma0 = diams[iblur] / 30. * gblur * xr[iblur]
273
+ sigma1 = sigma0.clone() / 10.
274
+ else:
275
+ xrand = np.random.exponential(1, size=iblur.sum())
276
+ xrand = np.clip(xrand * 0.5, 0.1, 1.0)
277
+ xrand *= gblur
278
+ sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to(
279
+ device)
280
+ sigma1 = sigma0.clone()
281
+ else:
282
+ sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device)
283
+ sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device)
284
+
285
+ # create gaussian filter
286
+ xr = max(8, sigma0.max().long() * 2)
287
+ gfilt0 = torch.exp(-torch.arange(-xr + 1, xr, device=device)**2 /
288
+ (2 * sigma0.unsqueeze(-1)**2))
289
+ gfilt0 /= gfilt0.sum(axis=-1, keepdims=True)
290
+ gfilt1 = torch.zeros_like(gfilt0)
291
+ gfilt1[sigma1 == sigma0] = gfilt0[sigma1 == sigma0]
292
+ gfilt1[sigma1 != sigma0] = torch.exp(
293
+ -torch.arange(-xr + 1, xr, device=device)**2 /
294
+ (2 * sigma1[sigma1 != sigma0].unsqueeze(-1)**2))
295
+ gfilt1[sigma1 == 0] = 0.
296
+ gfilt1[sigma1 == 0, xr] = 1.
297
+ gfilt1 /= gfilt1.sum(axis=-1, keepdims=True)
298
+ gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1)
299
+ gfilt /= gfilt.sum(axis=(1, 2), keepdims=True)
300
+
301
+ lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1),
302
+ padding=gfilt.shape[-1] // 2,
303
+ groups=gfilt.shape[0]).transpose(1, 0)
304
+ if partial_blur:
305
+ #yc, xc = np.random.randint(100, Ly-100), np.random.randint(100, Lx-100)
306
+ imgi[iblur] = lbl[iblur].clone()
307
+ Lxc = int(Lx * 0.85)
308
+ ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32),
309
+ torch.arange(0, Lxc, dtype=torch.float32),
310
+ indexing="ij")
311
+ mask = torch.exp(-(ym**2 + xm**2) / 2*(0.001**2))
312
+ mask -= mask.min()
313
+ mask /= mask.max()
314
+ lbl_blur_crop = lbl_blur[:, :, :, :Lxc]
315
+ imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask +
316
+ (1-mask) * imgi[iblur, :, :, :Lxc])
317
+ else:
318
+ imgi[iblur] = lbl_blur
319
+
320
+ imgi[~iblur] = lbl[~iblur]
321
+
322
+ # apply downsample
323
+ for k in ii:
324
+ i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]]
325
+ imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear")
326
+
327
+ # add poisson noise
328
+ ipoisson = np.random.rand(len(lbl)) < poisson
329
+ if ipoisson.sum() > 0:
330
+ if pscale is None:
331
+ pscale = torch.zeros(len(lbl))
332
+ m = torch.distributions.gamma.Gamma(alpha, beta)
333
+ pscale = torch.clamp(m.rsample(sample_shape=(ipoisson.sum(),)), 1.)
334
+ #pscale = torch.clamp(20 * (torch.rand(size=(len(lbl),), device=lbl.device)), 1.5)
335
+ pscale = pscale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device)
336
+ else:
337
+ pscale = pscale * torch.ones((ipoisson.sum(), 1, 1, 1), device=device)
338
+ imgi[ipoisson] = torch.poisson(pscale * imgi[ipoisson])
339
+ imgi[~ipoisson] = imgi[~ipoisson]
340
+
341
+ # renormalize
342
+ imgi = img_norm(imgi)
343
+
344
+ return imgi
345
+
346
+
347
+ def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7,
348
+ downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30,
349
+ ds_max=7, uniform_blur=False, iso=True, rotate=True,
350
+ device=torch.device("cuda"), xy=(224, 224),
351
+ nchan_noise=1, keep_raw=True):
352
+ """
353
+ Applies random rotation, resizing, and noise to the input data.
354
+
355
+ Args:
356
+ data (numpy.ndarray): The input data.
357
+ labels (numpy.ndarray, optional): The flow and cellprob labels associated with the data. Defaults to None.
358
+ diams (float, optional): The diameter of the objects. Defaults to None.
359
+ poisson (float, optional): The Poisson noise probability. Defaults to 0.7.
360
+ blur (float, optional): The blur probability. Defaults to 0.7.
361
+ downsample (float, optional): The downsample probability. Defaults to 0.0.
362
+ beta (float, optional): The beta value for the poisson noise distribution. Defaults to 0.7.
363
+ gblur (float, optional): The Gaussian blur level. Defaults to 1.0.
364
+ diam_mean (float, optional): The mean diameter. Defaults to 30.
365
+ ds_max (int, optional): The maximum downsample value. Defaults to 7.
366
+ iso (bool, optional): Whether to apply isotropic augmentation. Defaults to True.
367
+ rotate (bool, optional): Whether to apply rotation augmentation. Defaults to True.
368
+ device (torch.device, optional): The device to use. Defaults to torch.device("cuda").
369
+ xy (tuple, optional): The size of the output image. Defaults to (224, 224).
370
+ nchan_noise (int, optional): The number of channels to add noise to. Defaults to 1.
371
+ keep_raw (bool, optional): Whether to keep the raw image. Defaults to True.
372
+
373
+ Returns:
374
+ torch.Tensor: The augmented image and augmented noisy/blurry/downsampled version of image.
375
+ torch.Tensor: The augmented labels.
376
+ float: The scale factor applied to the image.
377
+ """
378
+ if device is None:
379
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
380
+
381
+ diams = 30 if diams is None else diams
382
+ random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1))
383
+ random_rsc = diams / random_diam #/ random_diam
384
+ #rsc /= random_scale
385
+ xy0 = (340, 340)
386
+ nchan = data[0].shape[0]
387
+ data_new = np.zeros((len(data), (1 + keep_raw) * nchan, xy0[0], xy0[1]), "float32")
388
+ labels_new = np.zeros((len(data), 3, xy0[0], xy0[1]), "float32")
389
+ for i in range(
390
+ len(data)): #, (sc, img, lbl) in enumerate(zip(random_rsc, data, labels)):
391
+ sc = random_rsc[i]
392
+ img = data[i]
393
+ lbl = labels[i] if labels is not None else None
394
+ # create affine transform to resize
395
+ Ly, Lx = img.shape[-2:]
396
+ dxy = np.maximum(0, np.array([Lx / sc - xy0[1], Ly / sc - xy0[0]]))
397
+ dxy = (np.random.rand(2,) - .5) * dxy
398
+ cc = np.array([Lx / 2, Ly / 2])
399
+ cc1 = cc - np.array([Lx - xy0[1], Ly - xy0[0]]) / 2 + dxy
400
+ pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
401
+ pts2 = np.float32(
402
+ [cc1, cc1 + np.array([1, 0]) / sc, cc1 + np.array([0, 1]) / sc])
403
+ M = cv2.getAffineTransform(pts1, pts2)
404
+
405
+ # apply to image
406
+ for c in range(nchan):
407
+ img_rsz = cv2.warpAffine(img[c], M, xy0, flags=cv2.INTER_LINEAR)
408
+ #img_noise = add_noise(torch.from_numpy(img_rsz).to(device).unsqueeze(0)).cpu().numpy().squeeze(0)
409
+ data_new[i, c] = img_rsz
410
+ if keep_raw:
411
+ data_new[i, c + nchan] = img_rsz
412
+
413
+ if lbl is not None:
414
+ # apply to labels
415
+ labels_new[i, 0] = cv2.warpAffine(lbl[0], M, xy0, flags=cv2.INTER_NEAREST)
416
+ labels_new[i, 1] = cv2.warpAffine(lbl[1], M, xy0, flags=cv2.INTER_LINEAR)
417
+ labels_new[i, 2] = cv2.warpAffine(lbl[2], M, xy0, flags=cv2.INTER_LINEAR)
418
+
419
+ rsc = random_diam / diam_mean
420
+
421
+ # add noise before augmentations
422
+ img = torch.from_numpy(data_new).to(device)
423
+ img = torch.clamp(img, 0.)
424
+ # just add noise to cyto if nchan_noise=1
425
+ img[:, :nchan_noise] = add_noise(
426
+ img[:, :nchan_noise], poisson=poisson, blur=blur, ds_max=ds_max, iso=iso,
427
+ downsample=downsample, beta=beta, gblur=gblur,
428
+ diams=torch.from_numpy(random_diam).to(device).float())
429
+ # img -= img.mean(dim=(-2,-1), keepdim=True)
430
+ # img /= img.std(dim=(-2,-1), keepdim=True) + 1e-3
431
+ img = img.cpu().numpy()
432
+
433
+ # augmentations
434
+ img, lbl, scale = transforms.random_rotate_and_resize(
435
+ img,
436
+ Y=labels_new,
437
+ xy=xy,
438
+ rotate=False if not iso else rotate,
439
+ #(iso and downsample==0),
440
+ rescale=rsc,
441
+ scale_range=0.5)
442
+ img = torch.from_numpy(img).to(device)
443
+ lbl = torch.from_numpy(lbl).to(device)
444
+
445
+ return img, lbl, scale
446
+
447
+
448
+ def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
449
+ """
450
+ Creates a Cellpose network with a single input channel.
451
+
452
+ Args:
453
+ device (str): The device to run the network on.
454
+ model_type (str, optional): The type of Cellpose model to use. Defaults to "cyto2".
455
+ pretrained_model (str, optional): The path to a pretrained model file. Defaults to None.
456
+
457
+ Returns:
458
+ torch.nn.Module: The Cellpose network with a single input channel.
459
+ """
460
+ if pretrained_model is not None and not os.path.exists(pretrained_model):
461
+ model_type = pretrained_model
462
+ pretrained_model = None
463
+ nbase = [32, 64, 128, 256]
464
+ nchan = 1
465
+ net1 = resnet_torch.CPnet([nchan, *nbase], nout=3, sz=3).to(device)
466
+ filename = model_path(model_type,
467
+ 0) if pretrained_model is None else pretrained_model
468
+ weights = torch.load(filename, weights_only=True)
469
+ zp = 0
470
+ print(filename)
471
+ for name in net1.state_dict():
472
+ if ("res_down_0.conv.conv_0" not in name and
473
+ #"output" not in name and
474
+ "res_down_0.proj" not in name and name != "diam_mean" and
475
+ name != "diam_labels"):
476
+ net1.state_dict()[name].copy_(weights[name])
477
+ elif "res_down_0" in name:
478
+ if len(weights[name].shape) > 0:
479
+ new_weight = torch.zeros_like(net1.state_dict()[name])
480
+ if weights[name].shape[0] == 2:
481
+ new_weight[:] = weights[name][0]
482
+ elif len(weights[name].shape) > 1 and weights[name].shape[1] == 2:
483
+ new_weight[:, zp] = weights[name][:, 0]
484
+ else:
485
+ new_weight = weights[name]
486
+ else:
487
+ new_weight = weights[name]
488
+ net1.state_dict()[name].copy_(new_weight)
489
+ return net1
490
+
491
+
492
+ class CellposeDenoiseModel():
493
+ """ model to run Cellpose and Image restoration """
494
+
495
+ def __init__(self, gpu=False, pretrained_model=False, model_type=None,
496
+ restore_type="denoise_cyto3", nchan=2,
497
+ chan2_restore=False, device=None):
498
+
499
+ self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore,
500
+ device=device)
501
+ self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan,
502
+ pretrained_model=pretrained_model, device=device)
503
+
504
+ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
505
+ normalize=True, rescale=None, diameter=None, tile_overlap=0.1,
506
+ augment=False, resample=True, invert=False, flow_threshold=0.4,
507
+ cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
508
+ min_size=15, niter=None, interp=True, bsize=224, flow3D_smooth=0):
509
+ """
510
+ Restore array or list of images using the image restoration model, and then segment.
511
+
512
+ Args:
513
+ x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
514
+ batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
515
+ (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
516
+ channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
517
+ First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
518
+ Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
519
+ For instance, to segment grayscale images, input [0,0]. To segment images with cells
520
+ in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
521
+ image with cells in green and nuclei in blue, input [[0,0], [2,3]].
522
+ Defaults to None.
523
+ channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
524
+ if None, channels dimension is attempted to be automatically determined. Defaults to None.
525
+ z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
526
+ if None, z dimension is attempted to be automatically determined. Defaults to None.
527
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
528
+ can also pass dictionary of parameters (all keys are optional, default values shown):
529
+ - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
530
+ - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
531
+ - "normalize"=True ; run normalization (if False, all following parameters ignored)
532
+ - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
533
+ - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
534
+ - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
535
+ Defaults to True.
536
+ rescale (float, optional): resize factor for each image, if None, set to 1.0;
537
+ (only used if diameter is None). Defaults to None.
538
+ diameter (float, optional): diameter for each image,
539
+ if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
540
+ tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
541
+ augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
542
+ resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
543
+ invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
544
+ flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
545
+ cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
546
+ do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False.
547
+ anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
548
+ stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
549
+ min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
550
+ flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0.
551
+ niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None.
552
+ interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True.
553
+
554
+ Returns:
555
+ A tuple containing (masks, flows, styles, imgs); masks: labelled image(s), where 0=no masks; 1,2,...=mask labels;
556
+ flows: list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration;
557
+ styles: style vector summarizing each image of size 256;
558
+ imgs: Restored images.
559
+ """
560
+
561
+ if isinstance(normalize, dict):
562
+ normalize_params = {**normalize_default, **normalize}
563
+ elif not isinstance(normalize, bool):
564
+ raise ValueError("normalize parameter must be a bool or a dict")
565
+ else:
566
+ normalize_params = normalize_default
567
+ normalize_params["normalize"] = normalize
568
+ normalize_params["invert"] = invert
569
+
570
+ img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels,
571
+ channel_axis=channel_axis, z_axis=z_axis,
572
+ do_3D=do_3D,
573
+ normalize=normalize_params, rescale=rescale,
574
+ diameter=diameter,
575
+ tile_overlap=tile_overlap, bsize=bsize)
576
+
577
+ # turn off special normalization for segmentation
578
+ normalize_params = normalize_default
579
+
580
+ # change channels for segmentation
581
+ if channels is not None:
582
+ channels_new = [0, 0] if channels[0] == 0 else [1, 2]
583
+ else:
584
+ channels_new = None
585
+ # change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean)
586
+ diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
587
+ masks, flows, styles = self.cp.eval(
588
+ img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
589
+ z_axis=0 if not isinstance(img_restore, list) and img_restore.ndim > 3 and img_restore.shape[0] > 0 else None,
590
+ normalize=normalize_params, rescale=rescale, diameter=diameter,
591
+ tile_overlap=tile_overlap, augment=augment, resample=resample,
592
+ invert=invert, flow_threshold=flow_threshold,
593
+ cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy,
594
+ stitch_threshold=stitch_threshold, min_size=min_size, niter=niter,
595
+ interp=interp, bsize=bsize)
596
+
597
+ return masks, flows, styles, img_restore
598
+
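
End-to-end usage sketch (assumptions: pretrained weights are downloadable and the input is a 2D grayscale numpy image): restore first, then segment the restored image.

import numpy as np

img = np.random.rand(512, 512).astype("float32")          # stand-in for a real image
model = CellposeDenoiseModel(gpu=True, model_type="cyto3",
                             restore_type="denoise_cyto3")
masks, flows, styles, img_dn = model.eval(img, channels=[0, 0], diameter=30.)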
599
+
600
+ class DenoiseModel():
601
+ """
602
+ DenoiseModel class for denoising images using Cellpose denoising model.
603
+
604
+ Args:
605
+ gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
606
+ pretrained_model (bool or str or Path, optional): Pretrained model to use for denoising.
607
+ Can be a string or path. Defaults to False.
608
+ nchan (int, optional): Number of channels in the input images, all Cellpose 3 models were trained with nchan=1. Defaults to 1.
609
+ model_type (str, optional): Type of pretrained model to use ("denoise_cyto3", "deblur_cyto3", "upsample_cyto3", ...). Defaults to None.
610
+ chan2 (bool, optional): Whether to use a separate model for the second channel. Defaults to False.
611
+ diam_mean (float, optional): Mean diameter of the objects in the images. Defaults to 30.0.
612
+ device (torch.device, optional): Device to use for computation. Defaults to None.
613
+
614
+ Attributes:
615
+ nchan (int): Number of channels in the input images.
616
+ diam_mean (float): Mean diameter of the objects in the images.
617
+ net (CPnet): Cellpose network for denoising.
618
+ pretrained_model (bool or str or Path): Pretrained model path to use for denoising.
619
+ net_chan2 (CPnet or None): Cellpose network for the second channel, if applicable.
620
+ net_type (str): Type of the denoising network.
621
+
622
+ Methods:
623
+ eval(x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
624
+ normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1)
625
+ Denoise array or list of images using the denoising model.
626
+
627
+ _eval(net, x, normalize=True, rescale=None, diameter=None, tile=True,
628
+ tile_overlap=0.1)
629
+ Run denoising model on a single channel.
630
+ """
631
+
632
+ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
633
+ chan2=False, diam_mean=30., device=None):
634
+ self.nchan = nchan
635
+ if pretrained_model and (not isinstance(pretrained_model, str) and
636
+ not isinstance(pretrained_model, Path)):
637
+ raise ValueError("pretrained_model must be a string or path")
638
+
639
+ self.diam_mean = diam_mean
640
+ builtin = True
641
+ if model_type is not None or (pretrained_model and
642
+ not os.path.exists(pretrained_model)):
643
+ pretrained_model_string = model_type if model_type is not None else "denoise_cyto3"
644
+ if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]):
645
+ pretrained_model_string = "denoise_cyto3"
646
+ pretrained_model = model_path(pretrained_model_string)
647
+ if (pretrained_model and not os.path.exists(pretrained_model)):
648
+ denoise_logger.warning("pretrained model has incorrect path")
649
+ denoise_logger.info(f">> {pretrained_model_string} << model set to be used")
650
+ self.diam_mean = 17. if "nuclei" in pretrained_model_string else 30.
651
+ else:
652
+ if pretrained_model:
653
+ builtin = False
654
+ pretrained_model_string = pretrained_model
655
+ denoise_logger.info(f">>>> loading model {pretrained_model_string}")
656
+
657
+ # assign network device
658
+ if device is None:
659
+ sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
660
+ self.device = device if device is not None else sdevice
661
+ if device is not None:
662
+ device_gpu = self.device.type == "cuda"
663
+ self.gpu = gpu if device is None else device_gpu
664
+
665
+ # create network
666
+ self.nchan = nchan
667
+ self.nclasses = 1
668
+ nbase = [32, 64, 128, 256]
669
+ self.nchan = nchan
670
+ self.nbase = [nchan, *nbase]
671
+
672
+ self.net = CPnet(self.nbase, self.nclasses, sz=3,
673
+ max_pool=True, diam_mean=diam_mean).to(self.device)
674
+
675
+ self.pretrained_model = pretrained_model
676
+ self.net_chan2 = None
677
+ if self.pretrained_model:
678
+ self.net.load_model(self.pretrained_model, device=self.device)
679
+ denoise_logger.info(
680
+ f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)"
681
+ )
682
+ if chan2 and builtin:
683
+ chan2_path = model_path(
684
+ os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei")
685
+ print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
686
+ self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
687
+ max_pool=True,
688
+ diam_mean=17.).to(self.device)
689
+ self.net_chan2.load_model(chan2_path, device=self.device)
690
+ self.net_type = "cellpose_denoise"
691
+
692
+ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
693
+ normalize=True, rescale=None, diameter=None, tile=True, do_3D=False,
694
+ tile_overlap=0.1, bsize=224):
695
+ """
696
+ Restore array or list of images using the image restoration model.
697
+
698
+ Args:
699
+ x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
700
+ batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
701
+ (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
702
+ channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
703
+ First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
704
+ Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
705
+ For instance, to segment grayscale images, input [0,0]. To segment images with cells
706
+ in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
707
+ image with cells in green and nuclei in blue, input [[0,0], [2,3]].
708
+ Defaults to None.
709
+ channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
710
+ if None, channels dimension is attempted to be automatically determined. Defaults to None.
711
+ z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
712
+ if None, z dimension is attempted to be automatically determined. Defaults to None.
713
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
714
+ can also pass dictionary of parameters (all keys are optional, default values shown):
715
+ - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
716
+ - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
717
+ - "normalize"=True ; run normalization (if False, all following parameters ignored)
718
+ - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
719
+ - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
720
+ - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
721
+ Defaults to True.
722
+ rescale (float, optional): resize factor for each image, if None, set to 1.0;
723
+ (only used if diameter is None). Defaults to None.
724
+ diameter (float, optional): diameter for each image,
725
+ if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
726
+ tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
727
+
728
+ Returns:
729
+ list: A list of 2D/3D arrays of restored images
730
+
731
+ """
732
+ if isinstance(x, list) or x.squeeze().ndim == 5:
733
+ tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO)
734
+ nimg = len(x)
735
+ iterator = trange(nimg, file=tqdm_out,
736
+ mininterval=30) if nimg > 1 else range(nimg)
737
+ imgs = []
738
+ for i in iterator:
739
+ imgi = self.eval(
740
+ x[i], batch_size=batch_size,
741
+ channels=channels[i] if channels is not None and
742
+ ((len(channels) == len(x) and
743
+ (isinstance(channels[i], list) or
744
+ isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2))
745
+ else channels, channel_axis=channel_axis, z_axis=z_axis,
746
+ normalize=normalize,
747
+ do_3D=do_3D,
748
+ rescale=rescale[i] if isinstance(rescale, list) or
749
+ isinstance(rescale, np.ndarray) else rescale,
750
+ diameter=diameter[i] if isinstance(diameter, list) or
751
+ isinstance(diameter, np.ndarray) else diameter,
752
+ tile_overlap=tile_overlap, bsize=bsize)
753
+ imgs.append(imgi)
754
+ if isinstance(x, np.ndarray):
755
+ imgs = np.array(imgs)
756
+ return imgs
757
+
758
+ else:
759
+ # reshape image
760
+ x = transforms.convert_image(x, channels, channel_axis=channel_axis,
761
+ z_axis=z_axis, do_3D=do_3D, nchan=None)
762
+ if x.ndim < 4:
763
+ squeeze = True
764
+ x = x[np.newaxis, ...]
765
+ else:
766
+ squeeze = False
767
+
768
+ # may need to interpolate image before running upsampling
769
+ self.ratio = 1.
770
+ if "upsample" in self.pretrained_model:
771
+ Ly, Lx = x.shape[-3:-1]
772
+ if diameter is not None and 3 <= diameter < self.diam_mean:
773
+ self.ratio = self.diam_mean / diameter
774
+ denoise_logger.info(
775
+ f"upsampling image to {self.diam_mean} pixel diameter ({self.ratio:0.2f} times)"
776
+ )
777
+ Lyr, Lxr = int(Ly * self.ratio), int(Lx * self.ratio)
778
+ x = transforms.resize_image(x, Ly=Lyr, Lx=Lxr)
779
+ else:
780
+ denoise_logger.warning(
781
+ f"not interpolating image before upsampling because diameter is set >= {self.diam_mean}"
782
+ )
783
+ #raise ValueError(f"diameter is set to {diameter}, needs to be >=3 and < {self.dn.diam_mean}")
784
+
785
+ self.batch_size = batch_size
786
+
787
+ if diameter is not None and diameter > 0:
788
+ rescale = self.diam_mean / diameter
789
+ elif rescale is None:
790
+ rescale = 1.0
791
+
792
+ if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0):
793
+ x = x[..., :1]
794
+
795
+ for c in range(x.shape[-1]):
796
+ rescale0 = rescale * 30. / 17. if c == 1 else rescale
797
+ if c == 0 or self.net_chan2 is None:
+     x[..., c] = self._eval(self.net, x[..., c:c + 1], batch_size=batch_size,
+                            normalize=normalize, rescale=rescale0,
+                            tile_overlap=tile_overlap, bsize=bsize)[..., 0]
+ else:
+     x[..., c] = self._eval(self.net_chan2, x[..., c:c + 1], batch_size=batch_size,
+                            normalize=normalize, rescale=rescale0,
+                            tile_overlap=tile_overlap, bsize=bsize)[..., 0]
808
+ x = x[0] if squeeze else x
809
+ return x
810
+
811
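A minimal usage sketch of the eval API above, assuming the surrounding class is the DenoiseModel constructed in the `__main__` block further below; the model path, image, and diameter here are only illustrative:

```python
import numpy as np

# hypothetical model path; constructor arguments follow the __main__ block below
model = DenoiseModel(gpu=True, nchan=1, diam_mean=30.,
                     pretrained_model="/path/to/denoise_model")
img = np.random.rand(512, 512).astype("float32")   # synthetic grayscale image
restored = model.eval(img, channels=[0, 0], diameter=25., tile_overlap=0.1)
# restored keeps img's Ly x Lx; internally the image is rescaled so cells
# are ~diam_mean pixels before being run through the network
```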
+ def _eval(self, net, x, batch_size=8, normalize=True, rescale=None,
812
+ tile_overlap=0.1, bsize=224):
813
+ """
814
+ Run image restoration model on a single channel.
815
+
816
+ Args:
817
+ x (list, np.ndarray): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
818
+ batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
819
+ (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
820
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
821
+ can also pass dictionary of parameters (all keys are optional, default values shown):
822
+ - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
823
+ - "sharpen"=0 : sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
824
+ - "normalize"=True : run normalization (if False, all following parameters ignored)
825
+ - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
826
+ - "tile_norm"=0 : compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
827
+ - "norm3D"=False : compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
828
+ Defaults to True.
829
+ rescale (float, optional): resize factor for each image, if None, set to 1.0;
830
+ (only used if diameter is None). Defaults to None.
831
+ tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
832
+
833
+ Returns:
834
+ list: A list of 2D/3D arrays of restored images
835
+
836
+ """
837
+ if isinstance(normalize, dict):
838
+ normalize_params = {**normalize_default, **normalize}
839
+ elif not isinstance(normalize, bool):
840
+ raise ValueError("normalize parameter must be a bool or a dict")
841
+ else:
842
+ normalize_params = normalize_default
843
+ normalize_params["normalize"] = normalize
844
+
845
+ tic = time.time()
846
+ shape = x.shape
847
+ nimg = shape[0]
848
+
849
+ do_normalization = True if normalize_params["normalize"] else False
850
+
851
+ img = np.asarray(x)
852
+ if do_normalization:
853
+ img = transforms.normalize_img(img, **normalize_params)
854
+ if rescale != 1.0:
855
+ img = transforms.resize_image(img, rsz=rescale)
856
+ yf, style = run_net(net, img, bsize=bsize,  # use the net passed in (net or net_chan2)
857
+ tile_overlap=tile_overlap)
858
+ yf = transforms.resize_image(yf, shape[1], shape[2])
859
+ imgs = yf
860
+ del yf, style
861
+
862
+ # imgs = np.zeros((*x.shape[:-1], 1), np.float32)
863
+ # for i in iterator:
864
+ # img = np.asarray(x[i])
865
+ # if do_normalization:
866
+ # img = transforms.normalize_img(img, **normalize_params)
867
+ # if rescale != 1.0:
868
+ # img = transforms.resize_image(img, rsz=[rescale, rescale])
869
+ # if img.ndim == 2:
870
+ # img = img[:, :, np.newaxis]
871
+ # yf, style = run_net(net, img, batch_size=batch_size, augment=False,
872
+ # tile=tile, tile_overlap=tile_overlap, bsize=bsize)
873
+ # img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2])
874
+
875
+ # if img.ndim == 2:
876
+ # img = img[:, :, np.newaxis]
877
+ # imgs[i] = img
878
+ # del yf, style
879
+ net_time = time.time() - tic
880
+ if nimg > 1:
881
+ denoise_logger.info("imgs denoised in %2.2fs" % (net_time))
882
+
883
+ return imgs
884
+
885
+
886
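For reference, a normalize dictionary built from the keys listed in the docstrings above might look like this (the values are only illustrative; missing keys fall back to normalize_default):

```python
normalize_params = {
    "lowhigh": None,         # explicit [low, high] bounds; overrides everything below
    "sharpen": 0,            # high-pass sharpening radius in pixels (0 = off)
    "normalize": True,       # run percentile normalization
    "percentile": [1., 99.],
    "tile_norm": 0,          # tile size in pixels for local normalization (0 = off)
    "norm3D": False,
}
# restored = model.eval(img, normalize=normalize_params)
```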
+ def train(net, train_data=None, train_labels=None, train_files=None, test_data=None,
887
+ test_labels=None, test_files=None, train_probs=None, test_probs=None,
888
+ lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None,
889
+ save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0,
890
+ iso=True, uniform_blur=False, downsample=0., ds_max=7,
891
+ learning_rate=0.005, n_epochs=500,
892
+ weight_decay=0.00001, batch_size=8, nimg_per_epoch=None,
893
+ nimg_test_per_epoch=None, model_name=None):
894
+
895
+ # net properties
896
+ device = net.device
897
+ nchan = net.nchan
898
+ diam_mean = net.diam_mean.item()
899
+
900
+ args = np.array([poisson, beta, blur, gblur, downsample])
901
+ if args.ndim == 1:
902
+ args = args[:, np.newaxis]
903
+ poisson, beta, blur, gblur, downsample = args
904
+ nnoise = len(poisson)
905
+
906
+ d = datetime.datetime.now()
907
+ if save_path is not None:
908
+ if model_name is None:
909
+ filename = ""
910
+ lstrs = ["per", "seg", "rec"]
911
+ for k, (l, s) in enumerate(zip(lam, lstrs)):
912
+ filename += f"{s}_{l:.2f}_"
913
+ if not iso:
914
+ filename += "aniso_"
915
+ if poisson.sum() > 0:
916
+ filename += "poisson_"
917
+ if blur.sum() > 0:
918
+ filename += "blur_"
919
+ if downsample.sum() > 0:
920
+ filename += "downsample_"
921
+ filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f")
922
+ filename = os.path.join(save_path, filename)
923
+ else:
924
+ filename = os.path.join(save_path, model_name)
925
+ print(filename)
926
+ for i in range(len(poisson)):
927
+ denoise_logger.info(
928
+ f"poisson: {poisson[i]: 0.2f}, beta: {beta[i]: 0.2f}, blur: {blur[i]: 0.2f}, gblur: {gblur[i]: 0.2f}, downsample: {downsample[i]: 0.2f}"
929
+ )
930
+ net1 = one_chan_cellpose(device=device, pretrained_model=seg_model_type)
931
+
932
+ learning_rate_const = learning_rate
933
+ LR = np.linspace(0, learning_rate_const, 10)
934
+ LR = np.append(LR, learning_rate_const * np.ones(n_epochs - 100))
935
+ for i in range(10):
936
+ LR = np.append(LR, LR[-1] / 2 * np.ones(10))
937
+ learning_rate = LR
938
+
939
+ optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate[0],
941
+ weight_decay=weight_decay)
942
+ if train_data is not None:
943
+ nimg = len(train_data)
944
+ diam_train = np.array(
945
+ [utils.diameters(train_labels[k])[0] for k in trange(len(train_labels))])
946
+ diam_train[diam_train < 5] = 5.
947
+ if test_data is not None:
948
+ diam_test = np.array(
949
+ [utils.diameters(test_labels[k])[0] for k in trange(len(test_labels))])
950
+ diam_test[diam_test < 5] = 5.
951
+ nimg_test = len(test_data)
952
+ else:
953
+ nimg = len(train_files)
954
+ denoise_logger.info(">>> using files instead of loading dataset")
955
+ train_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in train_files]
956
+ denoise_logger.info(">>> computing diameters")
957
+ diam_train = np.array([
958
+ utils.diameters(io.imread(train_labels_files[k])[0])[0]
959
+ for k in trange(len(train_labels_files))
960
+ ])
961
+ diam_train[diam_train < 5] = 5.
962
+ if test_files is not None:
963
+ nimg_test = len(test_files)
964
+ test_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in test_files]
965
+ diam_test = np.array([
966
+ utils.diameters(io.imread(test_labels_files[k])[0])[0]
967
+ for k in trange(len(test_labels_files))
968
+ ])
969
+ diam_test[diam_test < 5] = 5.
970
+ train_probs = 1. / nimg * np.ones(nimg,
971
+ "float64") if train_probs is None else train_probs
972
+ if test_files is not None or test_data is not None:
973
+ test_probs = 1. / nimg_test * np.ones(
974
+ nimg_test, "float64") if test_probs is None else test_probs
975
+
976
+ tic = time.time()
977
+
978
+ nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
979
+ if test_files is not None or test_data is not None:
980
+ nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch
981
+
982
+ nbatch = 0
983
+ train_losses, test_losses = [], []
984
+ for iepoch in range(n_epochs):
985
+ np.random.seed(iepoch)
986
+ rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
987
+ p=train_probs)
988
+ torch.manual_seed(iepoch)
989
+ np.random.seed(iepoch)
990
+ for param_group in optimizer.param_groups:
991
+ param_group["lr"] = learning_rate[iepoch]
992
+ lavg, lavg_per, nsum = 0, 0, 0
993
+ for ibatch in range(0, nimg_per_epoch, batch_size * nnoise):
994
+ inds = rperm[ibatch : ibatch + batch_size * nnoise]
995
+ if train_data is None:
996
+ imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds]
997
+ lbls = [io.imread(train_labels_files[i])[1:] for i in inds]
998
+ else:
999
+ imgs = [train_data[i][:nchan] for i in inds]
1000
+ lbls = [train_labels[i][1:] for i in inds]
1001
+ #inoise = nbatch % nnoise
1002
+ rnoise = np.random.permutation(nnoise)
1003
+ for i, inoise in enumerate(rnoise):
1004
+ if i * batch_size < len(imgs):
1005
+ imgi, lbli, scale = random_rotate_and_resize_noise(
1006
+ imgs[i * batch_size : (i + 1) * batch_size],
1007
+ lbls[i * batch_size : (i + 1) * batch_size],
1008
+ diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(),
1009
+ poisson=poisson[inoise],
1010
+ beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso,
1011
+ downsample=downsample[inoise], uniform_blur=uniform_blur,
1012
+ diam_mean=diam_mean, ds_max=ds_max,
1013
+ device=device)
1014
+ if i == 0:
1015
+ img = imgi
1016
+ lbl = lbli
1017
+ else:
1018
+ img = torch.cat((img, imgi), axis=0)
1019
+ lbl = torch.cat((lbl, lbli), axis=0)
1020
+
1021
+ if nnoise > 0:
1022
+ iperm = np.random.permutation(img.shape[0])
1023
+ img, lbl = img[iperm], lbl[iperm]
1024
+
1025
+ for i in range(nnoise):
1026
+ optimizer.zero_grad()
1027
+ imgi = img[i * batch_size: (i + 1) * batch_size]
1028
+ lbli = lbl[i * batch_size: (i + 1) * batch_size]
1029
+ if imgi.shape[0] > 0:
1030
+ loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1,
1031
+ img=imgi[:, nchan:], lbl=lbli, lam=lam)
1032
+ loss.backward()
1033
+ optimizer.step()
1034
+ lavg += loss.item() * imgi.shape[0]
1035
+ lavg_per += loss_per.item() * imgi.shape[0]
1036
+
1037
+ nsum += len(img)
1038
+ nbatch += 1
1039
+
1040
+ if iepoch % 5 == 0 or iepoch < 10:
1041
+ lavg = lavg / nsum
1042
+ lavg_per = lavg_per / nsum
1043
+ if test_data is not None or test_files is not None:
1044
+ lavgt, nsum = 0., 0
1045
+ np.random.seed(42)
1046
+ rperm = np.random.choice(np.arange(0, nimg_test),
1047
+ size=(nimg_test_per_epoch,), p=test_probs)
1048
+ inoise = iepoch % nnoise
1049
+ torch.manual_seed(inoise)
1050
+ for ibatch in range(0, nimg_test_per_epoch, batch_size):
1051
+ inds = rperm[ibatch:ibatch + batch_size]
1052
+ if test_data is None:
1053
+ imgs = [
1054
+ np.maximum(0,
1055
+ io.imread(test_files[i])[:nchan]) for i in inds
1056
+ ]
1057
+ lbls = [io.imread(test_labels_files[i])[1:] for i in inds]
1058
+ else:
1059
+ imgs = [test_data[i][:nchan] for i in inds]
1060
+ lbls = [test_labels[i][1:] for i in inds]
1061
+ img, lbl, scale = random_rotate_and_resize_noise(
1062
+ imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise],
1063
+ beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise],
1064
+ iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur,
1065
+ diam_mean=diam_mean, ds_max=ds_max, device=device)
1066
+ loss, loss_per = test_loss(net, img[:, :nchan], net1=net1,
1067
+ img=img[:, nchan:], lbl=lbl, lam=lam)
1068
+
1069
+ lavgt += loss.item() * img.shape[0]
1070
+ nsum += len(img)
1071
+ lavgt = lavgt / nsum
1072
+ denoise_logger.info(
1073
+ "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f"
1074
+ % (iepoch, time.time() - tic, lavg, lavg_per, lavgt,
1075
+ learning_rate[iepoch]))
1076
+ test_losses.append(lavgt)
1077
+ else:
1078
+ denoise_logger.info(
1079
+ "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" %
1080
+ (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch]))
1081
+ train_losses.append(lavg)
1082
+
1083
+ if save_path is not None:
1084
+ if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
1085
+ if save_each: #separate files as model progresses
1086
+ filename0 = str(filename) + f"_epoch_{iepoch:04d}"
1087
+ else:
1088
+ filename0 = filename
1089
+ denoise_logger.info(f"saving network parameters to {filename0}")
1090
+ net.save_model(filename0)
1091
+ else:
1092
+ filename = save_path
1093
+
1094
+ return filename, train_losses, test_losses
1095
+
1096
+
1097
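The learning-rate schedule assembled at the top of train() (a 10-step linear warm-up, a constant plateau, then halving every 10 epochs over the final stretch) can be sketched on its own:

```python
import numpy as np

n_epochs, lr0 = 300, 0.005
LR = np.linspace(0, lr0, 10)                        # warm-up
LR = np.append(LR, lr0 * np.ones(n_epochs - 100))   # constant plateau
for _ in range(10):                                 # ten halvings, 10 epochs each
    LR = np.append(LR, LR[-1] / 2 * np.ones(10))
assert len(LR) >= n_epochs                          # indexed as LR[iepoch] above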
+ if __name__ == "__main__":
1098
+ import argparse
1099
+ parser = argparse.ArgumentParser(description="cellpose parameters")
1100
+
1101
+ input_img_args = parser.add_argument_group("input image arguments")
1102
+ input_img_args.add_argument("--dir", default=[], type=str,
1103
+ help="folder containing data to run or train on.")
1104
+ input_img_args.add_argument("--img_filter", default=[], type=str,
1105
+ help="end string for images to run on")
1106
+
1107
+ model_args = parser.add_argument_group("model arguments")
1108
+ model_args.add_argument("--pretrained_model", default=[], type=str,
1109
+ help="pretrained denoising model")
1110
+
1111
+ training_args = parser.add_argument_group("training arguments")
1112
+ training_args.add_argument("--test_dir", default=[], type=str,
1113
+ help="folder containing test data (optional)")
1114
+ training_args.add_argument("--file_list", default=[], type=str,
1115
+ help="npy file containing list of train and test files")
1116
+ training_args.add_argument("--seg_model_type", default="cyto2", type=str,
1117
+ help="model to use for seg training loss")
1118
+ training_args.add_argument(
1119
+ "--noise_type", default=[], type=str,
1120
+ help="noise type to use (if input, then other noise params are ignored)")
1121
+ training_args.add_argument("--poisson", default=0.8, type=float,
1122
+ help="fraction of images to add poisson noise to")
1123
+ training_args.add_argument("--beta", default=0.7, type=float,
1124
+ help="scale of poisson noise")
1125
+ training_args.add_argument("--blur", default=0., type=float,
1126
+ help="fraction of images to blur")
1127
+ training_args.add_argument("--gblur", default=1.0, type=float,
1128
+ help="scale of gaussian blurring stddev")
1129
+ training_args.add_argument("--downsample", default=0., type=float,
1130
+ help="fraction of images to downsample")
1131
+ training_args.add_argument("--ds_max", default=7, type=int,
1132
+ help="max downsampling factor")
1133
+ training_args.add_argument("--lam_per", default=1.0, type=float,
1134
+ help="weighting of perceptual loss")
1135
+ training_args.add_argument("--lam_seg", default=1.5, type=float,
1136
+ help="weighting of segmentation loss")
1137
+ training_args.add_argument("--lam_rec", default=0., type=float,
1138
+ help="weighting of reconstruction loss")
1139
+ training_args.add_argument(
1140
+ "--diam_mean", default=30., type=float, help=
1141
+ "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0"
1142
+ )
1143
+ training_args.add_argument("--learning_rate", default=0.001, type=float,
1144
+ help="learning rate. Default: %(default)s")
1145
+ training_args.add_argument("--n_epochs", default=2000, type=int,
1146
+ help="number of epochs. Default: %(default)s")
1147
+ training_args.add_argument(
1148
+ "--save_each", default=False, action="store_true",
1149
+ help="save each epoch as separate model")
1150
+ training_args.add_argument(
1151
+ "--nimg_per_epoch", default=0, type=int,
1152
+ help="number of images per epoch. Default is length of training images")
1153
+ training_args.add_argument(
1154
+ "--nimg_test_per_epoch", default=0, type=int,
1155
+ help="number of test images per epoch. Default is length of testing images")
1156
+
1157
+ io.logger_setup()
1158
+
1159
+ args = parser.parse_args()
1160
+ lams = [args.lam_per, args.lam_seg, args.lam_rec]
1161
+ print("lam", lams)
1162
+
1163
+ if len(args.noise_type) > 0:
1164
+ noise_type = args.noise_type
1165
+ uniform_blur = False
1166
+ iso = True
1167
+ if noise_type == "poisson":
1168
+ poisson = 0.8
1169
+ blur = 0.
1170
+ downsample = 0.
1171
+ beta = 0.7
1172
+ gblur = 1.0
1173
+ elif noise_type == "blur_expr":
1174
+ poisson = 0.8
1175
+ blur = 0.8
1176
+ downsample = 0.
1177
+ beta = 0.1
1178
+ gblur = 0.5
1179
+ elif noise_type == "blur":
1180
+ poisson = 0.8
1181
+ blur = 0.8
1182
+ downsample = 0.
1183
+ beta = 0.1
1184
+ gblur = 10.0
1185
+ uniform_blur = True
1186
+ elif noise_type == "downsample_expr":
1187
+ poisson = 0.8
1188
+ blur = 0.8
1189
+ downsample = 0.8
1190
+ beta = 0.03
1191
+ gblur = 1.0
1192
+ elif noise_type == "downsample":
1193
+ poisson = 0.8
1194
+ blur = 0.8
1195
+ downsample = 0.8
1196
+ beta = 0.03
1197
+ gblur = 5.0
1198
+ uniform_blur = True
1199
+ elif noise_type == "all":
1200
+ poisson = [0.8, 0.8, 0.8]
1201
+ blur = [0., 0.8, 0.8]
1202
+ downsample = [0., 0., 0.8]
1203
+ beta = [0.7, 0.1, 0.03]
1204
+ gblur = [0., 10.0, 5.0]
1205
+ uniform_blur = True
1206
+ elif noise_type == "aniso":
1207
+ poisson = 0.8
1208
+ blur = 0.8
1209
+ downsample = 0.8
1210
+ beta = 0.1
1211
+ gblur = args.ds_max * 1.5
1212
+ iso = False
1213
+ else:
1214
+ raise ValueError(f"{noise_type} noise_type is not supported")
1215
+ else:
1216
+ poisson, beta = args.poisson, args.beta
1217
+ blur, gblur = args.blur, args.gblur
1218
+ downsample = args.downsample
1219
+
1220
+ pretrained_model = None if len(
1221
+ args.pretrained_model) == 0 else args.pretrained_model
1222
+ model = DenoiseModel(gpu=True, nchan=1, diam_mean=args.diam_mean,
1223
+ pretrained_model=pretrained_model)
1224
+
1225
+ train_data, labels, train_files, train_probs = None, None, None, None
1226
+ test_data, test_labels, test_files, test_probs = None, None, None, None
1227
+ if len(args.file_list) == 0:
1228
+ output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0)
1229
+ images, labels, image_names, test_images, test_labels, image_names_test = output
1230
+ train_data = []
1231
+ for i in range(len(images)):
1232
+ img = images[i].astype("float32")
1233
+ if img.ndim > 2:
1234
+ img = img[0]
1235
+ train_data.append(
1236
+ np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
1237
+ if len(args.test_dir) > 0:
1238
+ test_data = []
1239
+ for i in range(len(test_images)):
1240
+ img = test_images[i].astype("float32")
1241
+ if img.ndim > 2:
1242
+ img = img[0]
1243
+ test_data.append(
1244
+ np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
1245
+ save_path = os.path.join(args.dir, "../models/")
1246
+ else:
1247
+ root = args.dir
1248
+ denoise_logger.info(
1249
+ ">>> using file_list (assumes images are normalized and have flows!)")
1250
+ dat = np.load(args.file_list, allow_pickle=True).item()
1251
+ train_files = dat["train_files"]
1252
+ test_files = dat["test_files"]
1253
+ train_probs = dat["train_probs"] if "train_probs" in dat else None
1254
+ test_probs = dat["test_probs"] if "test_probs" in dat else None
1255
+ if str(train_files[0])[:len(str(root))] != str(root):
1256
+ for i in range(len(train_files)):
1257
+ new_path = root / Path(*train_files[i].parts[-3:])
1258
+ if i == 0:
1259
+ print(f"changing path from {train_files[i]} to {new_path}")
1260
+ train_files[i] = new_path
1261
+
1262
+ for i in range(len(test_files)):
1263
+ new_path = root / Path(*test_files[i].parts[-3:])
1264
+ test_files[i] = new_path
1265
+ save_path = os.path.join(args.dir, "models/")
1266
+
1267
+ os.makedirs(save_path, exist_ok=True)
1268
+
1269
+ nimg_per_epoch = None if args.nimg_per_epoch == 0 else args.nimg_per_epoch
1270
+ nimg_test_per_epoch = None if args.nimg_test_per_epoch == 0 else args.nimg_test_per_epoch
1271
+
1272
+ model_path, train_losses, test_losses = train(
1273
+ model.net, train_data=train_data, train_labels=labels, train_files=train_files,
1274
+ test_data=test_data, test_labels=test_labels, test_files=test_files,
1275
+ train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta,
1276
+ blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max,
1277
+ iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs,
1278
+ learning_rate=args.learning_rate,
1279
+ lam=lams,
1280
+ seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch,
1281
+ nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path)
1282
+
1283
+
1284
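One detail worth spelling out from train() above: scalar noise parameters are broadcast to length-1 arrays, while lists (as with noise_type "all") define several noise levels that are sampled within each epoch:

```python
import numpy as np

args = np.array([0.8, 0.7, 0., 1.0, 0.])   # poisson, beta, blur, gblur, downsample
if args.ndim == 1:
    args = args[:, np.newaxis]
poisson, beta, blur, gblur, downsample = args
print(len(poisson))  # nnoise = 1 here; with noise_type="all" each has length 3
```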
+ def seg_train_noisy(model, train_data, train_labels, test_data=None, test_labels=None,
1285
+ poisson=0.8, blur=0.0, downsample=0.0, save_path=None,
1286
+ save_every=100, save_each=False, learning_rate=0.2, n_epochs=500,
1287
+ momentum=0.9, weight_decay=0.00001, SGD=True, batch_size=8,
1288
+ nimg_per_epoch=None, diameter=None, rescale=True, z_masking=False,
1289
+ model_name=None):
1290
+ """ train function uses loss function model.loss_fn in models.py
1291
+
1292
+ (data should already be normalized)
1293
+
1294
+ """
1295
+
1296
+ d = datetime.datetime.now()
1297
+
1298
+ model.n_epochs = n_epochs
1299
+ if isinstance(learning_rate, (list, np.ndarray)):
1300
+ if isinstance(learning_rate, np.ndarray) and learning_rate.ndim > 1:
1301
+ raise ValueError("learning_rate.ndim must equal 1")
1302
+ elif len(learning_rate) != n_epochs:
1303
+ raise ValueError(
1304
+ "if learning_rate given as list or np.ndarray it must have length n_epochs"
1305
+ )
1306
+ model.learning_rate = learning_rate
1307
+ model.learning_rate_const = mode(learning_rate)[0][0]
1308
+ else:
1309
+ model.learning_rate_const = learning_rate
1310
+ # set learning rate schedule
1311
+ if SGD:
1312
+ LR = np.linspace(0, model.learning_rate_const, 10)
1313
+ if model.n_epochs > 250:
1314
+ LR = np.append(
1315
+ LR, model.learning_rate_const * np.ones(model.n_epochs - 100))
1316
+ for i in range(10):
1317
+ LR = np.append(LR, LR[-1] / 2 * np.ones(10))
1318
+ else:
1319
+ LR = np.append(
1320
+ LR,
1321
+ model.learning_rate_const * np.ones(max(0, model.n_epochs - 10)))
1322
+ else:
1323
+ LR = model.learning_rate_const * np.ones(model.n_epochs)
1324
+ model.learning_rate = LR
1325
+
1326
+ model.batch_size = batch_size
1327
+ model._set_optimizer(model.learning_rate[0], momentum, weight_decay, SGD)
1328
+ model._set_criterion()
1329
+
1330
+ nimg = len(train_data)
1331
+
1332
+ # compute average cell diameter
1333
+ if diameter is None:
1334
+ diam_train = np.array(
1335
+ [utils.diameters(train_labels[k][0])[0] for k in range(len(train_labels))])
1336
+ diam_train_mean = diam_train[diam_train > 0].mean()
1337
+ model.diam_labels = diam_train_mean
1338
+ if rescale:
1339
+ diam_train[diam_train < 5] = 5.
1340
+ if test_data is not None:
1341
+ diam_test = np.array([
1342
+ utils.diameters(test_labels[k][0])[0]
1343
+ for k in range(len(test_labels))
1344
+ ])
1345
+ diam_test[diam_test < 5] = 5.
1346
+ denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
1347
+ elif rescale:
1348
+ diam_train_mean = diameter
1349
+ model.diam_labels = diameter
1350
+ denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
1351
+ diam_train = diameter * np.ones(len(train_labels), "float32")
1352
+ if test_data is not None:
1353
+ diam_test = diameter * np.ones(len(test_labels), "float32")
1354
+
1355
+ denoise_logger.info(
1356
+ f">>>> mean of training label mask diameters (saved to model) {diam_train_mean:.3f}"
1357
+ )
1358
+ model.net.diam_labels.data = torch.ones(1, device=model.device) * diam_train_mean
1359
+
1360
+ nchan = train_data[0].shape[0]
1361
+ denoise_logger.info(">>>> training network with %d channel input <<<<" % nchan)
1362
+ denoise_logger.info(">>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f" %
1363
+ (model.learning_rate_const, model.batch_size, weight_decay))
1364
+
1365
+ if test_data is not None:
1366
+ denoise_logger.info(f">>>> ntrain = {nimg}, ntest = {len(test_data)}")
1367
+ else:
1368
+ denoise_logger.info(f">>>> ntrain = {nimg}")
1369
+
1370
+ tic = time.time()
1371
+
1372
+ lavg, nsum = 0, 0
1373
+
1374
+ if save_path is not None:
1375
+ _, file_label = os.path.split(save_path)
1376
+ file_path = os.path.join(save_path, "models/")
1377
+
1378
+ if not os.path.exists(file_path):
1379
+ os.makedirs(file_path)
1380
+ else:
1381
+ denoise_logger.warning("WARNING: no save_path given, model not saving")
1382
+
1383
+ ksave = 0
1384
+
1385
+ # get indices for each epoch for training
1386
+ np.random.seed(0)
1387
+ inds_all = np.zeros((0,), "int32")
1388
+ if nimg_per_epoch is None or nimg > nimg_per_epoch:
1389
+ nimg_per_epoch = nimg
1390
+ denoise_logger.info(f">>>> nimg_per_epoch = {nimg_per_epoch}")
1391
+ while len(inds_all) < n_epochs * nimg_per_epoch:
1392
+ rperm = np.random.permutation(nimg)
1393
+ inds_all = np.hstack((inds_all, rperm))
1394
+
1395
+ for iepoch in range(model.n_epochs):
1396
+ if SGD:
1397
+ model._set_learning_rate(model.learning_rate[iepoch])
1398
+ np.random.seed(iepoch)
1399
+ rperm = inds_all[iepoch * nimg_per_epoch:(iepoch + 1) * nimg_per_epoch]
1400
+ for ibatch in range(0, nimg_per_epoch, batch_size):
1401
+ inds = rperm[ibatch:ibatch + batch_size]
1402
+ imgi, lbl, scale = random_rotate_and_resize_noise(
1403
+ [train_data[i] for i in inds], [train_labels[i][1:] for i in inds],
1404
+ poisson=poisson, blur=blur, downsample=downsample,
1405
+ diams=diam_train[inds], diam_mean=model.diam_mean)
1406
+ imgi = imgi[:, :1] # keep noisy only
1407
+ if z_masking:
1408
+ nc = imgi.shape[1]
1409
+ nb = imgi.shape[0]
1410
+ ncmin = (np.random.rand(nb) > 0.25) * (np.random.randint(
1411
+ nc // 2 - 1, size=nb))
1412
+ ncmax = nc - (np.random.rand(nb) > 0.25) * (np.random.randint(
1413
+ nc // 2 - 1, size=nb))
1414
+ for b in range(nb):
1415
+ imgi[b, :ncmin[b]] = 0
1416
+ imgi[b, ncmax[b]:] = 0
1417
+
1418
+ train_loss = model._train_step(imgi, lbl)
1419
+ lavg += train_loss
1420
+ nsum += len(imgi)
1421
+
1422
+ if iepoch % 10 == 0 or iepoch == 5:
1423
+ lavg = lavg / nsum
1424
+ if test_data is not None:
1425
+ lavgt, nsum = 0., 0
1426
+ np.random.seed(42)
1427
+ rperm = np.arange(0, len(test_data), 1, int)
1428
+ for ibatch in range(0, len(test_data), batch_size):
1429
+ inds = rperm[ibatch:ibatch + batch_size]
1430
+ imgi, lbl, scale = random_rotate_and_resize_noise(
1431
+ [test_data[i] for i in inds],
1432
+ [test_labels[i][1:] for i in inds], poisson=poisson, blur=blur,
1433
+ downsample=downsample, diams=diam_test[inds],
1434
+ diam_mean=model.diam_mean)
1435
+ imgi = imgi[:, :1] # keep noisy only
1436
+ test_loss = model._test_eval(imgi, lbl)
1437
+ lavgt += test_loss
1438
+ nsum += len(imgi)
1439
+
1440
+ denoise_logger.info(
1441
+ "Epoch %d, Time %4.1fs, Loss %2.4f, Loss Test %2.4f, LR %2.4f" %
1442
+ (iepoch, time.time() - tic, lavg, lavgt / nsum,
1443
+ model.learning_rate[iepoch]))
1444
+ else:
1445
+ denoise_logger.info(
1446
+ "Epoch %d, Time %4.1fs, Loss %2.4f, LR %2.4f" %
1447
+ (iepoch, time.time() - tic, lavg, model.learning_rate[iepoch]))
1448
+
1449
+ lavg, nsum = 0, 0
1450
+
1451
+ if save_path is not None:
1452
+ if iepoch == model.n_epochs - 1 or iepoch % save_every == 1:
1453
+ # save model at the end
1454
+ if save_each: #separate files as model progresses
1455
+ if model_name is None:
1456
+ filename = "{}_{}_{}_{}".format(
1457
+ model.net_type, file_label,
1458
+ d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch))
1459
+ else:
1460
+ filename = "{}_{}".format(model_name, "epoch_" + str(iepoch))
1461
+ else:
1462
+ if model_name is None:
1463
+ filename = "{}_{}_{}".format(model.net_type, file_label,
1464
+ d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
1465
+ else:
1466
+ filename = model_name
1467
+ filename = os.path.join(file_path, filename)
1468
+ ksave += 1
1469
+ denoise_logger.info(f"saving network parameters to {filename}")
1470
+ model.net.save_model(filename)
1471
+ else:
1472
+ filename = save_path
1473
+
1474
+ return filename
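A hedged sketch of calling seg_train_noisy: `model` stands in for a CellposeModel-style object exposing the _train_step/_test_eval/_set_optimizer hooks used above, the data are synthetic, and the label layout follows the indexing in the training loop (labels[k][0] = masks, labels[k][1:] = flows):

```python
import numpy as np

# hypothetical data; in practice from io.load_train_test_data + labels_to_flows
train_data = [np.random.rand(1, 256, 256).astype("float32") for _ in range(8)]
train_labels = [np.random.randint(0, 5, (4, 256, 256)).astype("float32")
                for _ in range(8)]
# filename = seg_train_noisy(model, train_data, train_labels, poisson=0.8,
#                            n_epochs=100, save_path="./runs")  # hypothetical path
```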
models/seg_post_model/cellpose/dynamics.py ADDED
@@ -0,0 +1,691 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import os
5
+ from scipy.ndimage import find_objects, mean  # center_of_mass is defined locally below
6
+ import torch
7
+ import numpy as np
8
+ import tifffile
9
+ from tqdm import trange
10
+ import fastremap
11
+
12
+ import logging
13
+
14
+ dynamics_logger = logging.getLogger(__name__)
15
+
16
+ from . import utils
17
+
18
+ import torch.nn.functional as F
20
+
21
+ def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
22
+ device=torch.device("cpu")):
23
+ """Runs diffusion on GPU to generate flows for training images or quality control.
24
+
25
+ Args:
26
+ neighbors (torch.Tensor): 9 x pixels in masks.
27
+ meds (torch.Tensor): Mask centers.
28
+ isneighbor (torch.Tensor): Valid neighbor boolean 9 x pixels.
29
+ shape (tuple): Shape of the tensor.
30
+ n_iter (int, optional): Number of iterations. Defaults to 200.
31
+ device (torch.device, optional): Device to run the computation on. Defaults to torch.device("cpu").
32
+
33
+ Returns:
34
+ np.ndarray: Generated flows (dy, dx) for the mask pixels.
35
+
36
+ """
37
+ if torch.prod(torch.tensor(shape)) > 4e7 or device.type == "mps":
38
+ T = torch.zeros(shape, dtype=torch.float, device=device)
39
+ else:
40
+ T = torch.zeros(shape, dtype=torch.double, device=device)
41
+
42
+ for i in range(n_iter):
43
+ T[tuple(meds.T)] += 1
44
+ Tneigh = T[tuple(neighbors)]
45
+ Tneigh *= isneighbor
46
+ T[tuple(neighbors[:, 0])] = Tneigh.mean(axis=0)
47
+ del meds, isneighbor, Tneigh
48
+
49
+ if T.ndim == 2:
50
+ grads = T[neighbors[0, [2, 1, 4, 3]], neighbors[1, [2, 1, 4, 3]]]
51
+ del neighbors
52
+ dy = grads[0] - grads[1]
53
+ dx = grads[2] - grads[3]
54
+ del grads
55
+ mu_torch = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
56
+ else:
57
+ grads = T[tuple(neighbors[:, 1:])]
58
+ del neighbors
59
+ dz = grads[0] - grads[1]
60
+ dy = grads[2] - grads[3]
61
+ dx = grads[4] - grads[5]
62
+ del grads
63
+ mu_torch = np.stack(
64
+ (dz.cpu().squeeze(0), dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
65
+ return mu_torch
66
+
67
+ def center_of_mass(mask):
68
+ yi, xi = np.nonzero(mask)
69
+ ymean = int(np.round(yi.sum() / len(yi)))
70
+ xmean = int(np.round(xi.sum() / len(xi)))
71
+ if not ((yi==ymean) * (xi==xmean)).sum():
72
+ # center is closest point to (ymean, xmean) within mask
73
+ imin = ((xi - xmean)**2 + (yi - ymean)**2).argmin()
74
+ ymean = yi[imin]
75
+ xmean = xi[imin]
76
+
77
+ return ymean, xmean
78
+
79
+ def get_centers(masks, slices):
80
+ centers = [center_of_mass(masks[slices[i]]==(i+1)) for i in range(len(slices))]
81
+ centers = np.array([np.array([centers[i][0] + slices[i][0].start, centers[i][1] + slices[i][1].start])
82
+ for i in range(len(slices))])
83
+ exts = np.array([(slc[0].stop - slc[0].start) + (slc[1].stop - slc[1].start) + 2 for slc in slices])
84
+ return centers, exts
85
+
86
+
87
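The snapping step in center_of_mass matters for non-convex masks, where the plain centroid can fall outside the mask itself. A quick check with a hollow ring whose centroid lands in the hole:

```python
import numpy as np

ring = np.ones((5, 5), bool)
ring[1:4, 1:4] = False       # hollow ring: the centroid (2, 2) is not in the mask
print(center_of_mass(ring))  # snaps to the in-mask pixel nearest the centroid
```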
+ def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
88
+ """Convert masks to flows using diffusion from center pixel.
89
+
90
+ Center of masks where diffusion starts is defined by pixel closest to median within the mask.
91
+
92
+ Args:
93
+ masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
94
+ device (torch.device, optional): The device to run the computation on. Defaults to torch.device("cpu").
95
+ niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
96
+
97
+ Returns:
+ np.ndarray: Float array of flows of shape [2 x Ly x Lx]; unit-norm (dy, dx) vectors inside masks, zero elsewhere.
104
+ """
105
+ if device is None:
106
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
107
+
108
+ if masks.max() > 0:
109
+ Ly0, Lx0 = masks.shape
110
+ Ly, Lx = Ly0 + 2, Lx0 + 2
111
+
112
+ masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
113
+ masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
114
+ shape = masks_padded.shape
115
+
116
+ ### get mask pixel neighbors
117
+ y, x = torch.nonzero(masks_padded, as_tuple=True)
118
+ y = y.int()
119
+ x = x.int()
120
+ neighbors = torch.zeros((2, 9, y.shape[0]), dtype=torch.int, device=device)
121
+ yxi = [[0, -1, 1, 0, 0, -1, -1, 1, 1], [0, 0, 0, -1, 1, -1, 1, -1, 1]]
122
+ for i in range(9):
123
+ neighbors[0, i] = y + yxi[0][i]
124
+ neighbors[1, i] = x + yxi[1][i]
125
+ isneighbor = torch.ones((9, y.shape[0]), dtype=torch.bool, device=device)
126
+ m0 = masks_padded[neighbors[0, 0], neighbors[1, 0]]
127
+ for i in range(1, 9):
128
+ isneighbor[i] = masks_padded[neighbors[0, i], neighbors[1, i]] == m0
129
+ del m0, masks_padded
130
+
131
+ ### get center-of-mass within cell
132
+ slices = find_objects(masks)
133
+ centers, ext = get_centers(masks, slices)
134
+ meds_p = torch.from_numpy(centers).to(device).long()
135
+ meds_p += 1 # for padding
136
+
137
+ ### run diffusion
138
+ n_iter = 2 * ext.max() if niter is None else niter
139
+ mu = _extend_centers_gpu(neighbors, meds_p, isneighbor, shape, n_iter=n_iter,
140
+ device=device)
141
+ mu = mu.astype("float64")
142
+
143
+ # new normalization
144
+ mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)
145
+
146
+ # put into original image
147
+ mu0 = np.zeros((2, Ly0, Lx0))
148
+ mu0[:, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
149
+ else:
150
+ # no masks, return empty flows
151
+ mu0 = np.zeros((2, masks.shape[0], masks.shape[1]))
152
+ return mu0
153
+
154
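A toy example of the 2D flow computation above, using one square cell; flows point toward the cell center:

```python
import numpy as np
import torch

masks = np.zeros((32, 32), int)
masks[8:24, 8:24] = 1
mu = masks_to_flows_gpu(masks, device=torch.device("cpu"))
print(mu.shape)  # (2, 32, 32): unit-norm (dy, dx) inside the mask, zeros outside
```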
+ def masks_to_flows_gpu_3d(masks, device=None, niter=None):
155
+ """Convert masks to flows using diffusion from center pixel.
156
+
157
+ Args:
158
+ masks (int, 2D or 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
159
+ device (torch.device, optional): The device to run the computation on. Defaults to None.
160
+ niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
161
+
162
+ Returns:
163
+ np.ndarray: A 4D array representing the flows for each pixel in Z, X, and Y.
164
+
165
+ """
166
+ if device is None:
167
+ device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
168
+
169
+ Lz0, Ly0, Lx0 = masks.shape
170
+ Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2
171
+
172
+ masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
173
+ masks_padded = F.pad(masks_padded, (1, 1, 1, 1, 1, 1))
174
+
175
+ # get mask pixel neighbors
176
+ z, y, x = torch.nonzero(masks_padded).T
177
+ neighborsZ = torch.stack((z, z + 1, z - 1, z, z, z, z))
178
+ neighborsY = torch.stack((y, y, y, y + 1, y - 1, y, y), axis=0)
179
+ neighborsX = torch.stack((x, x, x, x, x, x + 1, x - 1), axis=0)
180
+
181
+ neighbors = torch.stack((neighborsZ, neighborsY, neighborsX), axis=0)
182
+
183
+ # get mask centers
184
+ slices = find_objects(masks)
185
+
186
+ centers = np.zeros((masks.max(), 3), "int")
187
+ for i, si in enumerate(slices):
188
+ if si is not None:
189
+ sz, sy, sx = si
190
+ zi, yi, xi = np.nonzero(masks[sz, sy, sx] == (i + 1))
191
+ zi = zi.astype(np.int32) + 1 # add padding
192
+ yi = yi.astype(np.int32) + 1 # add padding
193
+ xi = xi.astype(np.int32) + 1 # add padding
194
+ zmed = np.mean(zi)
195
+ ymed = np.mean(yi)
196
+ xmed = np.mean(xi)
197
+ imin = np.argmin((zi - zmed)**2 + (xi - xmed)**2 + (yi - ymed)**2)
198
+ zmed = zi[imin]
199
+ ymed = yi[imin]
200
+ xmed = xi[imin]
201
+ centers[i, 0] = zmed + sz.start
202
+ centers[i, 1] = ymed + sy.start
203
+ centers[i, 2] = xmed + sx.start
204
+
205
+ # get neighbor validator (not all neighbors are in same mask)
206
+ neighbor_masks = masks_padded[tuple(neighbors)]
207
+ isneighbor = neighbor_masks == neighbor_masks[0]
208
+ ext = np.array(
209
+ [[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1]
210
+ for sz, sy, sx in slices])
211
+ n_iter = 6 * (ext.sum(axis=1)).max() if niter is None else niter
212
+
213
+ # run diffusion
214
+ shape = masks_padded.shape
215
+ mu = _extend_centers_gpu(neighbors, centers, isneighbor, shape, n_iter=n_iter,
216
+ device=device)
217
+ # normalize
218
+ mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)
219
+
220
+ # put into original image
221
+ mu0 = np.zeros((3, Lz0, Ly0, Lx0))
222
+ mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
223
+ return mu0
224
+
225
+ def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None,
226
+ return_flows=True):
227
+ """Converts labels (list of masks or flows) to flows for training model.
228
+
229
+ Args:
230
+ labels (list of ND-arrays): The labels to convert. labels[k] can be 2D or 3D. If [3 x Ly x Lx],
231
+ it is assumed that flows were precomputed. Otherwise, labels[k][0] or labels[k] (if 2D)
232
+ is used to create flows and cell probabilities.
233
+ files (list of str, optional): The files to save the flows to. If provided, flows are saved to
234
+ files to be reused. Defaults to None.
235
+ device (str, optional): The device to use for computation. Defaults to None.
236
+ redo_flows (bool, optional): Whether to recompute the flows. Defaults to False.
237
+ niter (int, optional): The number of iterations for computing flows. Defaults to None.
238
+
239
+ Returns:
240
+ list of [4 x Ly x Lx] arrays: The flows for training the model. flows[k][0] is labels[k],
241
+ flows[k][1] is the binary cell mask, flows[k][2] is the Y flow, and flows[k][3] is the X flow.
243
+ """
244
+ nimg = len(labels)
245
+ if labels[0].ndim < 3:
246
+ labels = [labels[n][np.newaxis, :, :] for n in range(nimg)]
247
+
248
+ flows = []
249
+ # flows need to be recomputed
250
+ if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows:
251
+ dynamics_logger.info("computing flows for labels")
252
+
253
+ # compute flows; labels are fixed here to be unique, so they need to be passed back
254
+ # make sure labels are unique!
255
+ labels = [fastremap.renumber(label, in_place=True)[0] for label in labels]
256
+ iterator = trange if nimg > 1 else range
257
+ for n in iterator(nimg):
258
+ labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0]
259
+ vecn = masks_to_flows_gpu(labels[n][0].astype(int), device=device, niter=niter)
260
+
261
+ # concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations)
262
+ flow = np.concatenate((labels[n], labels[n] > 0.5, vecn),
263
+ axis=0).astype(np.float32)
264
+ if files is not None:
265
+ file_name = os.path.splitext(files[n])[0]
266
+ tifffile.imwrite(file_name + "_flows.tif", flow)
267
+ if return_flows:
268
+ flows.append(flow)
269
+ else:
270
+ dynamics_logger.info("flows precomputed")
271
+ if return_flows:
272
+ flows = [labels[n].astype(np.float32) for n in range(nimg)]
273
+ return flows
274
+
275
+
276
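Continuing the toy example, converting a list of label images into training flows with the function above:

```python
import numpy as np
import torch

masks = np.zeros((32, 32), "int32")
masks[8:24, 8:24] = 1
flows = labels_to_flows([masks], device=torch.device("cpu"))
print(flows[0].shape)  # (4, 32, 32): labels, binary mask, dy flow, dx flow
```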
+ def flow_error(maski, dP_net, device=None):
277
+ """Error in flows from predicted masks vs flows predicted by network run on image.
278
+
279
+ This function serves to benchmark the quality of masks. It works as follows:
280
+ 1. The predicted masks are used to create a flow diagram.
281
+ 2. The mask-flows are compared to the flows that the network predicted.
282
+
283
+ If there is a discrepancy between the flows, it suggests that the mask is incorrect.
284
+ Masks with flow_errors greater than 0.4 are discarded by default. This setting can be
285
+ changed in Cellpose.eval or CellposeModel.eval.
286
+
287
+ Args:
288
+ maski (np.ndarray, int): Masks produced from running dynamics on dP_net, where 0=NO masks; 1,2... are mask labels.
289
+ dP_net (np.ndarray, float): ND flows where dP_net.shape[1:] = maski.shape.
290
+
291
+ Returns:
292
+ A tuple containing (flow_errors, dP_masks): flow_errors (np.ndarray, float): Mean squared error between predicted flows and flows from masks;
293
+ dP_masks (np.ndarray, float): ND flows produced from the predicted masks.
294
+ """
295
+ if dP_net.shape[1:] != maski.shape:
296
+ print("ERROR: net flow is not same size as predicted masks")
297
+ return
298
+
299
+ # flows predicted from estimated masks
300
+ dP_masks = masks_to_flows_gpu(maski, device=device)
301
+ # difference between predicted flows vs mask flows
302
+ flow_errors = np.zeros(maski.max())
303
+ for i in range(dP_masks.shape[0]):
304
+ flow_errors += mean((dP_masks[i] - dP_net[i] / 5.)**2, maski,
305
+ index=np.arange(1,
306
+ maski.max() + 1))
307
+
308
+ return flow_errors, dP_masks
309
+
310
+
311
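A self-consistency check of the QC metric: flows derived from a mask, scaled by 5 to mimic network output, should give near-zero error for that same mask:

```python
import numpy as np
import torch

maski = np.zeros((64, 64), "uint16")
maski[10:30, 10:30] = 1                     # one square "cell"
dP_net = 5. * masks_to_flows_gpu(maski.astype(int), device=torch.device("cpu"))
merrors, _ = flow_error(maski, dP_net, device=torch.device("cpu"))
print(merrors)  # ~0; masks with error above flow_threshold (0.4) get discarded
```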
+ def steps_interp(dP, inds, niter, device=torch.device("cpu")):
312
+ """ Run dynamics of pixels to recover masks in 2D/3D, with interpolation between pixel values.
313
+
314
+ Euler integration of dynamics dP for niter steps.
315
+
316
+ Args:
317
+ dP (numpy.ndarray): Array of shape (2, Ly, Lx) or (3, Lz, Ly, Lx) representing the flow field.
+ inds (tuple of numpy.ndarray): Indices of the starting pixels (as returned by np.nonzero).
+ niter (int): Number of iterations to perform.
+ device (torch.device, optional): Device to use for computation. Defaults to torch.device("cpu").
+
+ Returns:
+ torch.Tensor: Tensor of shape (2 or 3, n_points) with the final pixel locations.
327
+
328
+ """
329
+
330
+ shape = dP.shape[1:]
331
+ ndim = len(shape)
332
+
333
+ pt = torch.zeros((*[1]*ndim, len(inds[0]), ndim), dtype=torch.float32, device=device)
334
+ im = torch.zeros((1, ndim, *shape), dtype=torch.float32, device=device)
335
+ # Y and X dimensions, flipped X-1, Y-1
336
+ # pt is [1 1 1 3 n_points]
337
+ for n in range(ndim):
338
+ if ndim==3:
339
+ pt[0, 0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
340
+ else:
341
+ pt[0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
342
+ im[0, ndim - n - 1] = torch.from_numpy(dP[n]).to(device, dtype=torch.float32)
343
+ shape = np.array(shape)[::-1].astype("float") - 1
344
+
345
+ # normalize pt between 0 and 1, normalize the flow
346
+ for k in range(ndim):
347
+ im[:, k] *= 2. / shape[k]
348
+ pt[..., k] /= shape[k]
349
+
350
+ # normalize to between -1 and 1
351
+ pt *= 2
352
+ pt -= 1
353
+
354
+ # dynamics
355
+ for t in range(niter):
356
+ dPt = torch.nn.functional.grid_sample(im, pt, align_corners=False)
357
+ for k in range(ndim): #clamp the final pixel locations
358
+ pt[..., k] = torch.clamp(pt[..., k] + dPt[:, k], -1., 1.)
359
+
360
+ #undo the normalization from before, reverse order of operations
361
+ pt += 1
362
+ pt *= 0.5
363
+ for k in range(ndim):
364
+ pt[..., k] *= shape[k]
365
+
366
+ if ndim==3:
367
+ pt = pt[..., [2, 1, 0]].squeeze()
368
+ pt = pt.unsqueeze(0) if pt.ndim==1 else pt
369
+ return pt.T
370
+ else:
371
+ pt = pt[..., [1, 0]].squeeze()
372
+ pt = pt.unsqueeze(0) if pt.ndim==1 else pt
373
+ return pt.T
374
+
375
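steps_interp leans on torch's grid_sample convention: with align_corners=False, grid coordinates in [-1, 1] span the pixel edges, so grid point (0, 0) is the exact image center. A small sanity check:

```python
import torch
import torch.nn.functional as F

im = torch.arange(16.).reshape(1, 1, 4, 4)  # N x C x H x W
pt = torch.zeros(1, 1, 1, 2)                # one sample at grid (x, y) = (0, 0)
print(F.grid_sample(im, pt, align_corners=False))
# tensor([[[[7.5000]]]]): bilinear average of the four central pixels (5, 6, 9, 10)
```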
+ def follow_flows(dP, inds, niter=200, device=torch.device("cpu")):
376
+ """ Run dynamics to recover masks in 2D or 3D.
377
+
378
+ Pixels are represented as a meshgrid. Only pixels with non-zero cell-probability
379
+ are used (as defined by inds).
380
+
381
+ Args:
382
+ dP (np.ndarray): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
383
+ inds (tuple of np.ndarray): Indices of the pixels on which to run dynamics (as returned by np.nonzero).
+ niter (int, optional): Number of iterations of dynamics to run. Default is 200.
+ device (torch.device, optional): Device to use for computation. Default is torch.device("cpu").
+
+ Returns:
+ p (torch.Tensor): Final locations of each pixel after dynamics, shape (2 or 3, n_points).
391
+ """
392
+
395
+ p = steps_interp(dP, inds, niter, device=device)
396
+
397
+ return p
398
+
399
+
400
+ def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu")):
401
+ """Remove masks which have inconsistent flows.
402
+
403
+ Uses metrics.flow_error to compute flows from predicted masks
404
+ and compare flows to predicted flows from the network. Discards
405
+ masks with flow errors greater than the threshold.
406
+
407
+ Args:
408
+ masks (int, 2D or 3D array): Labelled masks, 0=NO masks; 1,2,...=mask labels,
409
+ size [Ly x Lx] or [Lz x Ly x Lx].
410
+ flows (float, 3D or 4D array): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
411
+ threshold (float, optional): Masks with flow error greater than threshold are discarded.
412
+ Default is 0.4.
413
+
414
+ Returns:
415
+ masks (int, 2D or 3D array): Masks with inconsistent flow masks removed,
416
+ 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx].
417
+ """
418
+ device0 = device
419
+ if masks.size > 10000 * 10000 and (device is not None and device.type == "cuda"):
420
+
421
+ major_version, minor_version = torch.__version__.split(".")[:2]
422
+ torch.cuda.empty_cache()
423
+ if major_version == "1" and int(minor_version) < 10:
424
+ # for PyTorch version lower than 1.10
425
+ def mem_info():
426
+ total_mem = torch.cuda.get_device_properties(device0.index).total_memory
427
+ used_mem = torch.cuda.memory_allocated(device0.index)
428
+ free_mem = total_mem - used_mem
429
+ return total_mem, free_mem
430
+ else:
431
+ # for PyTorch version 1.10 and above
432
+ def mem_info():
433
+ free_mem, total_mem = torch.cuda.mem_get_info(device0.index)
434
+ return total_mem, free_mem
435
+ total_mem, free_mem = mem_info()
436
+ if masks.size * 32 > free_mem:
437
+ dynamics_logger.warning(
438
+ "WARNING: image is very large, not using gpu to compute flows from masks for QC step flow_threshold"
439
+ )
440
+ dynamics_logger.info("turn off QC step with flow_threshold=0 if too slow")
441
+ device0 = torch.device("cpu")
442
+
443
+ merrors, _ = flow_error(masks, flows, device0)
444
+ badi = 1 + (merrors > threshold).nonzero()[0]
445
+ masks[np.isin(masks, badi)] = 0
446
+ return masks
447
+
448
+
449
+ def max_pool1d(h, kernel_size=5, axis=1, out=None):
450
+ """ memory efficient max_pool thanks to Mark Kittisopikul
451
+
452
+ for stride=1, padding=kernel_size//2, requires odd kernel_size >= 3
453
+
454
+ """
455
+ if out is None:
456
+ out = h.clone()
457
+ else:
458
+ out.copy_(h)
459
+
460
+ nd = h.shape[axis]
461
+ k0 = kernel_size // 2
462
+ for d in range(-k0, k0+1):
463
+ if axis==1:
464
+ mv = out[:, max(-d,0):min(nd-d,nd)]
465
+ hv = h[:, max(d,0):min(nd+d,nd)]
466
+ elif axis==2:
467
+ mv = out[:, :, max(-d,0):min(nd-d,nd)]
468
+ hv = h[:, :, max(d,0):min(nd+d,nd)]
469
+ elif axis==3:
470
+ mv = out[:, :, :, max(-d,0):min(nd-d,nd)]
471
+ hv = h[:, :, :, max(d,0):min(nd+d,nd)]
472
+ torch.maximum(mv, hv, out=mv)
473
+ return out
474
+
475
+ def max_pool_nd(h, kernel_size=5):
476
+ """ memory efficient max_pool in 2d or 3d """
477
+ ndim = h.ndim - 1
478
+ hmax = max_pool1d(h, kernel_size=kernel_size, axis=1)
479
+ hmax2 = max_pool1d(hmax, kernel_size=kernel_size, axis=2)
480
+ if ndim==2:
481
+ del hmax
482
+ return hmax2
483
+ else:
484
+ hmax = max_pool1d(hmax2, kernel_size=kernel_size, axis=3, out=hmax)
485
+ del hmax2
486
+ return hmax
487
+
488
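On non-negative inputs, the shifted-slice trick in max_pool1d/max_pool_nd reproduces a stride-1, same-padded max pool, as this check against torch's built-in (using max_pool_nd defined above) suggests:

```python
import torch
import torch.nn.functional as F

h = torch.rand(1, 32, 32)   # batch x Ly x Lx
a = max_pool_nd(h, kernel_size=5)
b = F.max_pool2d(h.unsqueeze(1), 5, stride=1, padding=2).squeeze(1)
print(torch.allclose(a, b))  # True
```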
+ def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
489
+ """Create masks using pixel convergence after running dynamics.
490
+
491
+ Makes a histogram of final pixel locations p, initializes masks
492
+ at peaks of histogram and extends the masks from the peaks so that
493
+ they include all pixels with more than 2 final pixels p. Discards
494
+ masks with flow errors greater than the threshold.
495
+
496
+ Parameters:
497
+ pt (torch.Tensor): Final locations of each pixel after dynamics, shape (ndim, n_points).
+ inds (tuple of np.ndarray): Original indices of the mask pixels, used to scatter labels back into the image.
+ shape0 (tuple): Shape of the original image.
501
+ rpad (int, optional): Histogram edge padding. Default is 20.
502
+ max_size_fraction (float, optional): Masks larger than max_size_fraction of
503
+ total image size are removed. Default is 0.4.
504
+
505
+ Returns:
506
+ M0 (int, 2D or 3D array): Masks with inconsistent flow masks removed,
507
+ 0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx].
508
+ """
509
+
510
+ ndim = len(shape0)
511
+ device = pt.device
512
+
513
+ pt += rpad
515
+ pt = torch.clamp(pt, min=0)
516
+ for i in range(len(pt)):
517
+ pt[i] = torch.clamp(pt[i], max=shape0[i]+rpad-1)
518
+
519
+ # # add extra padding to make divisible by 5
520
+ # shape = tuple((np.ceil((shape0 + 2*rpad)/5) * 5).astype(int))
521
+ shape = tuple(np.array(shape0) + 2*rpad)
522
+
523
+ # sparse coo torch
524
+ coo = torch.sparse_coo_tensor(pt, torch.ones(pt.shape[1], device=pt.device, dtype=torch.int),
525
+ shape)
526
+ h1 = coo.to_dense()
527
+ del coo
528
+
529
+ hmax1 = max_pool_nd(h1.unsqueeze(0), kernel_size=5)
530
+ hmax1 = hmax1.squeeze()
531
+ seeds1 = torch.nonzero((h1 - hmax1 > -1e-6) * (h1 > 10))
532
+ del hmax1
533
+ if len(seeds1) == 0:
534
+ dynamics_logger.warning("no seeds found in get_masks_torch - no masks found.")
535
+ return np.zeros(shape0, dtype="uint16")
536
+
537
+ npts = h1[tuple(seeds1.T)]
538
+ isort1 = npts.argsort()
539
+ seeds1 = seeds1[isort1]
540
+
541
+ n_seeds = len(seeds1)
542
+ h_slc = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
543
+ for k in range(n_seeds):
544
+ slc = tuple([slice(seeds1[k][j]-5, seeds1[k][j]+6) for j in range(ndim)])
545
+ h_slc[k] = h1[slc]
546
+ del h1
547
+ seed_masks = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
548
+ if ndim==2:
549
+ seed_masks[:,5,5] = 1
550
+ else:
551
+ seed_masks[:,5,5,5] = 1
552
+
553
+ for _ in range(5):
554
+ # extend
555
+ seed_masks = max_pool_nd(seed_masks, kernel_size=3)
556
+ seed_masks *= h_slc > 2
557
+ del h_slc
558
+ seeds_new = [tuple((torch.nonzero(seed_masks[k]) + seeds1[k] - 5).T)
559
+ for k in range(n_seeds)]
560
+ del seed_masks
561
+
562
+ dtype = torch.int32 if n_seeds < 2**16 else torch.int64
563
+ M1 = torch.zeros(shape, dtype=dtype, device=device)
564
+ for k in range(n_seeds):
565
+ M1[seeds_new[k]] = 1 + k
566
+
567
+ M1 = M1[tuple(pt)]
568
+ M1 = M1.cpu().numpy()
569
+
570
+ dtype = "uint16" if n_seeds < 2**16 else "uint32"
571
+ M0 = np.zeros(shape0, dtype=dtype)
572
+ M0[inds] = M1
573
+
574
+ # remove big masks
575
+ uniq, counts = fastremap.unique(M0, return_counts=True)
576
+ big = np.prod(shape0) * max_size_fraction
577
+ bigc = uniq[counts > big]
578
+ if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
579
+ M0 = fastremap.mask(M0, bigc)
580
+ fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
581
+ M0 = M0.reshape(tuple(shape0))
582
+
583
+ #print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9:.3f} gb")
584
+ return M0
585
+
586
+
587
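get_masks_torch builds its histogram of final pixel positions with a sparse COO tensor; duplicate coordinates accumulate when densified, which is exactly what makes convergence peaks visible:

```python
import torch

pt = torch.tensor([[0, 0, 1],
                   [2, 2, 3]])   # three pixels landing at (0,2), (0,2), (1,3)
h = torch.sparse_coo_tensor(pt, torch.ones(3, dtype=torch.int), (4, 4)).to_dense()
print(h[0, 2].item(), h[1, 3].item())  # 2 1 -- coincident points add up
```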
+ def resize_and_compute_masks(dP, cellprob, niter=200, cellprob_threshold=0.0,
588
+ flow_threshold=0.4, do_3D=False, min_size=15,
589
+ max_size_fraction=0.4, resize=None, device=torch.device("cpu")):
590
+ """Compute masks using dynamics from dP and cellprob, and resizes masks if resize is not None.
591
+
592
+ Args:
593
+ dP (numpy.ndarray): The dynamics flow field array.
594
+ cellprob (numpy.ndarray): The cell probability array.
595
+ niter (int, optional): The number of iterations for mask computation. Defaults to 200.
+ cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0.
+ flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4.
+ do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
+ min_size (int, optional): The minimum size of the masks. Defaults to 15.
+ max_size_fraction (float, optional): Masks larger than max_size_fraction of
+ total image size are removed. Default is 0.4.
+ resize (tuple, optional): Deprecated; resizing is no longer applied and a warning is logged instead. Defaults to None.
+ device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu").
+
+ Returns:
+ numpy.ndarray: The computed masks (0 = no mask; 1, 2, ... = mask labels).
609
+ """
610
+ mask = compute_masks(dP, cellprob, niter=niter,
611
+ cellprob_threshold=cellprob_threshold,
612
+ flow_threshold=flow_threshold, do_3D=do_3D,
613
+ max_size_fraction=max_size_fraction,
614
+ device=device)
615
+
616
+ if resize is not None:
617
+ dynamics_logger.warning("Resizing is deprecated in v4.0.1+")
618
+
619
+ mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)
620
+
621
+ return mask
622
+
623
+
624
+def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
+                  flow_threshold=0.4, do_3D=False, min_size=-1,
+                  max_size_fraction=0.4, device=torch.device("cpu")):
+    """Compute masks using dynamics from dP and cellprob.
+
+    Args:
+        dP (numpy.ndarray): The dynamics flow field array.
+        cellprob (numpy.ndarray): The cell probability array.
+        p (numpy.ndarray, optional): The pixels on which to run dynamics. Defaults to None
+            (currently unused; the pixels are derived from cellprob).
+        niter (int, optional): The number of iterations for mask computation. Defaults to 200.
+        cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0.
+        flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4.
+        do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
+        min_size (int, optional): The minimum size of the masks; values <= 0 skip
+            small-mask removal. Defaults to -1.
+        max_size_fraction (float, optional): Masks larger than max_size_fraction of
+            total image size are removed. Default is 0.4.
+        device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu").
+
+    Returns:
+        numpy.ndarray: The computed masks.
+    """
+
+    if (cellprob > cellprob_threshold).sum():  # mask at this point is a cell cluster binary map, not labels
+        inds = np.nonzero(cellprob > cellprob_threshold)
+        if len(inds[0]) == 0:
+            dynamics_logger.info("No cell pixels found.")
+            shape = cellprob.shape
+            mask = np.zeros(shape, "uint16")
+            return mask
+
+        p_final = follow_flows(dP * (cellprob > cellprob_threshold) / 5.,
+                               inds=inds, niter=niter,
+                               device=device)
+        if not torch.is_tensor(p_final):
+            p_final = torch.from_numpy(p_final).to(device, dtype=torch.int)
+        else:
+            p_final = p_final.int()
+        # calculate masks
+        if device.type == "mps":
+            p_final = p_final.to(torch.device("cpu"))
+        mask = get_masks_torch(p_final, inds, dP.shape[1:],
+                               max_size_fraction=max_size_fraction)
+        del p_final
+        # flow thresholding factored out of get_masks
+        if not do_3D:
+            if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0:
+                # make sure labels are unique at output of get_masks
+                mask = remove_bad_flow_masks(mask, dP, threshold=flow_threshold,
+                                             device=device)
+
+        if mask.max() < 2**16 and mask.dtype != "uint16":
+            mask = mask.astype("uint16")
+
+    else:  # nothing to compute, just make it compatible
+        dynamics_logger.info("No cell pixels found.")
+        mask = np.zeros(cellprob.shape, "uint16")
+        return mask
+
+    if min_size > 0:
+        mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)
+
+    if mask.dtype == np.uint32:
+        dynamics_logger.warning(
+            "more than 65535 masks in image, masks returned as np.uint32")
+
+    return mask
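
Editor's note: a minimal usage sketch for `compute_masks` follows. The import path and array shapes are assumptions (this repo vendors the module under `models/seg_post_model/cellpose`; for a 2D image, `dP` has shape `(2, Ly, Lx)` and `cellprob` has shape `(Ly, Lx)`), and the arrays here are placeholders rather than real network output:

```python
import numpy as np
import torch
# assumed import path; adjust to models.seg_post_model.cellpose.dynamics if needed
from cellpose.dynamics import compute_masks

Ly, Lx = 256, 256
dP = np.zeros((2, Ly, Lx), dtype=np.float32)           # network flow field
cellprob = np.random.randn(Ly, Lx).astype(np.float32)  # pixel "cellness" logits

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
mask = compute_masks(dP, cellprob, niter=200, cellprob_threshold=0.0,
                     flow_threshold=0.4, min_size=15, device=device)
print(mask.shape, int(mask.max()), "labels")           # 0 = background
```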
models/seg_post_model/cellpose/export.py ADDED
@@ -0,0 +1,405 @@
+"""Auxiliary module for bioimageio format export
+
+Example usage:
+
+```bash
+#!/bin/bash
+
+# Define default paths and parameters
+DEFAULT_CHANNELS="1 0"
+DEFAULT_PATH_PRETRAINED_MODEL="/home/qinyu/models/cp/cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995"
+DEFAULT_PATH_README="/home/qinyu/models/cp/README.md"
+DEFAULT_LIST_PATH_COVER_IMAGES="/home/qinyu/images/cp/cellpose_raw_and_segmentation.jpg /home/qinyu/images/cp/cellpose_raw_and_probability.jpg /home/qinyu/images/cp/cellpose_raw.jpg"
+DEFAULT_MODEL_ID="philosophical-panda"
+DEFAULT_MODEL_ICON="🐼"
+DEFAULT_MODEL_VERSION="0.1.0"
+DEFAULT_MODEL_NAME="My Cool Cellpose"
+DEFAULT_MODEL_DOCUMENTATION="A cool Cellpose model trained for my cool dataset."
+DEFAULT_MODEL_AUTHORS='[{"name": "Qin Yu", "affiliation": "EMBL", "github_user": "qin-yu", "orcid": "0000-0002-4652-0795"}]'
+DEFAULT_MODEL_CITE='[{"text": "For more details of the model itself, see the manuscript", "doi": "10.1242/dev.202800", "url": null}]'
+DEFAULT_MODEL_TAGS="cellpose 3d 2d"
+DEFAULT_MODEL_LICENSE="MIT"
+DEFAULT_MODEL_REPO="https://github.com/kreshuklab/go-nuclear"
+
+# Run the Python script with default parameters
+python export.py \
+    --channels $DEFAULT_CHANNELS \
+    --path_pretrained_model "$DEFAULT_PATH_PRETRAINED_MODEL" \
+    --path_readme "$DEFAULT_PATH_README" \
+    --list_path_cover_images $DEFAULT_LIST_PATH_COVER_IMAGES \
+    --model_version "$DEFAULT_MODEL_VERSION" \
+    --model_name "$DEFAULT_MODEL_NAME" \
+    --model_documentation "$DEFAULT_MODEL_DOCUMENTATION" \
+    --model_authors "$DEFAULT_MODEL_AUTHORS" \
+    --model_cite "$DEFAULT_MODEL_CITE" \
+    --model_tags $DEFAULT_MODEL_TAGS \
+    --model_license "$DEFAULT_MODEL_LICENSE" \
+    --model_repo "$DEFAULT_MODEL_REPO"
+```
+"""
+
+import os
+import sys
+import json
+import argparse
+from pathlib import Path
+from urllib.parse import urlparse
+
+import torch
+import numpy as np
+
+from cellpose.io import imread
+from cellpose.utils import download_url_to_file
+from cellpose.transforms import pad_image_ND, normalize_img, convert_image
+from cellpose.vit_sam import CPnetBioImageIO
+
+from bioimageio.spec.model.v0_5 import (
+    ArchitectureFromFileDescr,
+    Author,
+    AxisId,
+    ChannelAxis,
+    CiteEntry,
+    Doi,
+    FileDescr,
+    Identifier,
+    InputTensorDescr,
+    IntervalOrRatioDataDescr,
+    LicenseId,
+    ModelDescr,
+    ModelId,
+    OrcidId,
+    OutputTensorDescr,
+    ParameterizedSize,
+    PytorchStateDictWeightsDescr,
+    SizeReference,
+    SpaceInputAxis,
+    SpaceOutputAxis,
+    TensorId,
+    TorchscriptWeightsDescr,
+    Version,
+    WeightsDescr,
+)
+# Define ARBITRARY_SIZE if it is not available in the module
+try:
+    from bioimageio.spec.model.v0_5 import ARBITRARY_SIZE
+except ImportError:
+    ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)
+
+from bioimageio.spec.common import HttpUrl
+from bioimageio.spec import save_bioimageio_package
+from bioimageio.core import test_model
+
+DEFAULT_CHANNELS = [2, 1]
+DEFAULT_NORMALIZE_PARAMS = {
+    "axis": -1,
+    "lowhigh": None,
+    "percentile": None,
+    "normalize": True,
+    "norm3D": False,
+    "sharpen_radius": 0,
+    "smooth_radius": 0,
+    "tile_norm_blocksize": 0,
+    "tile_norm_smooth3D": 1,
+    "invert": False,
+}
+IMAGE_URL = "http://www.cellpose.org/static/data/rgb_3D.tif"
+
+
+def download_and_normalize_image(path_dir_temp, channels=DEFAULT_CHANNELS):
+    """
+    Download and normalize image.
+    """
+    filename = os.path.basename(urlparse(IMAGE_URL).path)
+    path_image = path_dir_temp / filename
+    if not path_image.exists():
+        sys.stderr.write(f'Downloading: "{IMAGE_URL}" to {path_image}\n')
+        download_url_to_file(IMAGE_URL, path_image)
+    img = imread(path_image).astype(np.float32)
+    img = convert_image(img, channels, channel_axis=1, z_axis=0, do_3D=False, nchan=2)
+    img = normalize_img(img, **DEFAULT_NORMALIZE_PARAMS)
+    img = np.transpose(img, (0, 3, 1, 2))
+    img, _, _ = pad_image_ND(img)
+    return img
+
+
+def load_bioimageio_cpnet_model(path_model_weight, nchan=2):
+    cpnet_kwargs = {
+        "nout": 3,
+    }
+    cpnet_biio = CPnetBioImageIO(**cpnet_kwargs)
+    state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True)
+    cpnet_biio.load_state_dict(state_dict_cuda)
+    cpnet_biio.eval()  # crucial for the prediction results
+    return cpnet_biio, cpnet_kwargs
+
+
+def descr_gen_input(path_test_input, nchan=2):
+    input_axes = [
+        SpaceInputAxis(id=AxisId("z"), size=ARBITRARY_SIZE),
+        ChannelAxis(channel_names=[Identifier(f"c{i+1}") for i in range(nchan)]),
+        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=16, step=16)),
+        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=16, step=16)),
+    ]
+    data_descr = IntervalOrRatioDataDescr(type="float32")
+    path_test_input = Path(path_test_input)
+    descr_input = InputTensorDescr(
+        id=TensorId("raw"),
+        axes=input_axes,
+        test_tensor=FileDescr(source=path_test_input),
+        data=data_descr,
+    )
+    return descr_input
+
+
+def descr_gen_output_flow(path_test_output):
+    output_axes_output_tensor = [
+        SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
+        ChannelAxis(channel_names=[Identifier("flow1"), Identifier("flow2"), Identifier("flow3")]),
+        SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))),
+        SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"))),
+    ]
+    path_test_output = Path(path_test_output)
+    descr_output = OutputTensorDescr(
+        id=TensorId("flow"),
+        axes=output_axes_output_tensor,
+        test_tensor=FileDescr(source=path_test_output),
+    )
+    return descr_output
+
+
+def descr_gen_output_downsampled(path_dir_temp, nbase=None):
+    if nbase is None:
+        nbase = [32, 64, 128, 256]
+
+    output_axes_downsampled_tensors = [
+        [
+            SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
+            ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(base)]),
+            SpaceOutputAxis(
+                id=AxisId("y"),
+                size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
+                scale=2**offset,
+            ),
+            SpaceOutputAxis(
+                id=AxisId("x"),
+                size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")),
+                scale=2**offset,
+            ),
+        ]
+        for offset, base in enumerate(nbase)
+    ]
+    path_downsampled_tensors = [
+        Path(path_dir_temp / f"test_downsampled_{i}.npy") for i in range(len(output_axes_downsampled_tensors))
+    ]
+    descr_output_downsampled_tensors = [
+        OutputTensorDescr(
+            id=TensorId(f"downsampled_{i}"),
+            axes=axes,
+            test_tensor=FileDescr(source=path),
+        )
+        for i, (axes, path) in enumerate(zip(output_axes_downsampled_tensors, path_downsampled_tensors))
+    ]
+    return descr_output_downsampled_tensors
+
+
+def descr_gen_output_style(path_test_style, nchannel=256):
+    output_axes_style_tensor = [
+        SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
+        ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(nchannel)]),
+    ]
+    path_style_tensor = Path(path_test_style)
+    descr_output_style_tensor = OutputTensorDescr(
+        id=TensorId("style"),
+        axes=output_axes_style_tensor,
+        test_tensor=FileDescr(source=path_style_tensor),
+    )
+    return descr_output_style_tensor
+
+
+def descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper=None):
+    if path_cpnet_wrapper is None:
+        path_cpnet_wrapper = Path(__file__).parent / "resnet_torch.py"
+    pytorch_architecture = ArchitectureFromFileDescr(
+        callable=Identifier("CPnetBioImageIO"),
+        source=Path(path_cpnet_wrapper),
+        kwargs=cpnet_kwargs,
+    )
+    return pytorch_architecture
+
+
+def descr_gen_documentation(path_doc, markdown_text):
+    with open(path_doc, "w") as f:
+        f.write(markdown_text)
+
+
+def package_to_bioimageio(
+    path_pretrained_model,
+    path_save_trace,
+    path_readme,
+    list_path_cover_images,
+    descr_input,
+    descr_output,
+    descr_output_downsampled_tensors,
+    descr_output_style_tensor,
+    pytorch_version,
+    pytorch_architecture,
+    model_id,
+    model_icon,
+    model_version,
+    model_name,
+    model_documentation,
+    model_authors,
+    model_cite,
+    model_tags,
+    model_license,
+    model_repo,
+):
+    """Package model description to BioImage.IO format."""
+    my_model_descr = ModelDescr(
+        id=ModelId(model_id) if model_id is not None else None,
+        id_emoji=model_icon,
+        version=Version(model_version),
+        name=model_name,
+        description=model_documentation,
+        authors=[
+            Author(
+                name=author["name"],
+                affiliation=author["affiliation"],
+                github_user=author["github_user"],
+                orcid=OrcidId(author["orcid"]),
+            )
+            for author in model_authors
+        ],
+        cite=[CiteEntry(text=cite["text"], doi=Doi(cite["doi"]), url=cite["url"]) for cite in model_cite],
+        covers=[Path(img) for img in list_path_cover_images],
+        license=LicenseId(model_license),
+        tags=model_tags,
+        documentation=Path(path_readme),
+        git_repo=HttpUrl(model_repo),
+        inputs=[descr_input],
+        outputs=[descr_output, descr_output_style_tensor] + descr_output_downsampled_tensors,
+        weights=WeightsDescr(
+            pytorch_state_dict=PytorchStateDictWeightsDescr(
+                source=Path(path_pretrained_model),
+                architecture=pytorch_architecture,
+                pytorch_version=pytorch_version,
+            ),
+            torchscript=TorchscriptWeightsDescr(
+                source=Path(path_save_trace),
+                pytorch_version=pytorch_version,
+                parent="pytorch_state_dict",  # these weights were converted from the pytorch_state_dict weights.
+            ),
+        ),
+    )
+
+    return my_model_descr
+
+
+def parse_args():
+    # fmt: off
+    parser = argparse.ArgumentParser(description="BioImage.IO model packaging for Cellpose")
+    parser.add_argument("--channels", nargs=2, default=[2, 1], type=int, help="Cyto-only = [2, 0], Cyto + Nuclei = [2, 1], Nuclei-only = [1, 0]")
+    parser.add_argument("--path_pretrained_model", required=True, type=str, help="Path to pretrained model file, e.g., cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995")
+    parser.add_argument("--path_readme", required=True, type=str, help="Path to README file")
+    parser.add_argument("--list_path_cover_images", nargs='+', required=True, type=str, help="List of paths to cover images")
+    parser.add_argument("--model_id", type=str, help="Model ID, provide if already exists", default=None)
+    parser.add_argument("--model_icon", type=str, help="Model icon, provide if already exists", default=None)
+    parser.add_argument("--model_version", required=True, type=str, help="Model version, new model should be 0.1.0")
+    parser.add_argument("--model_name", required=True, type=str, help="Model name, e.g., My Cool Cellpose")
+    parser.add_argument("--model_documentation", required=True, type=str, help="Model documentation, e.g., A cool Cellpose model trained for my cool dataset.")
+    parser.add_argument("--model_authors", required=True, type=str, help="Model authors in JSON format, e.g., '[{\"name\": \"Qin Yu\", \"affiliation\": \"EMBL\", \"github_user\": \"qin-yu\", \"orcid\": \"0000-0002-4652-0795\"}]'")
+    parser.add_argument("--model_cite", required=True, type=str, help="Model citation in JSON format, e.g., '[{\"text\": \"For more details of the model itself, see the manuscript\", \"doi\": \"10.1242/dev.202800\", \"url\": null}]'")
+    parser.add_argument("--model_tags", nargs='+', required=True, type=str, help="Model tags, e.g., cellpose 3d 2d")
+    parser.add_argument("--model_license", required=True, type=str, help="Model license, e.g., MIT")
+    parser.add_argument("--model_repo", required=True, type=str, help="Model repository URL")
+    return parser.parse_args()
+    # fmt: on
+
+
+def main():
+    args = parse_args()
+
+    # Parse user-provided paths and arguments
+    channels = args.channels
+    model_cite = json.loads(args.model_cite)
+    model_authors = json.loads(args.model_authors)
+
+    path_readme = Path(args.path_readme)
+    path_pretrained_model = Path(args.path_pretrained_model)
+    list_path_cover_images = [Path(path_image) for path_image in args.list_path_cover_images]
+
+    # Auto-generated paths
+    path_cpnet_wrapper = Path(__file__).resolve().parent / "resnet_torch.py"
+    path_dir_temp = Path(__file__).resolve().parent.parent / "models" / path_pretrained_model.stem
+    path_dir_temp.mkdir(parents=True, exist_ok=True)
+
+    path_save_trace = path_dir_temp / "cp_traced.pt"
+    path_test_input = path_dir_temp / "test_input.npy"
+    path_test_output = path_dir_temp / "test_output.npy"
+    path_test_style = path_dir_temp / "test_style.npy"
+    path_bioimageio_package = path_dir_temp / "cellpose_model.zip"
+
+    # Download test input image
+    img_np = download_and_normalize_image(path_dir_temp, channels=channels)
+    np.save(path_test_input, img_np)
+    img = torch.tensor(img_np).float()
+
+    # Load model
+    cpnet_biio, cpnet_kwargs = load_bioimageio_cpnet_model(path_pretrained_model)
+
+    # Test model and save output
+    tuple_output_tensor = cpnet_biio(img)
+    np.save(path_test_output, tuple_output_tensor[0].detach().numpy())
+    np.save(path_test_style, tuple_output_tensor[1].detach().numpy())
+    for i, t in enumerate(tuple_output_tensor[2:]):
+        np.save(path_dir_temp / f"test_downsampled_{i}.npy", t.detach().numpy())
+
+    # Save traced model
+    model_traced = torch.jit.trace(cpnet_biio, img)
+    model_traced.save(path_save_trace)
+
+    # Generate model description
+    descr_input = descr_gen_input(path_test_input)
+    descr_output = descr_gen_output_flow(path_test_output)
+    descr_output_downsampled_tensors = descr_gen_output_downsampled(path_dir_temp, nbase=cpnet_biio.nbase[1:])
+    descr_output_style_tensor = descr_gen_output_style(path_test_style, cpnet_biio.nbase[-1])
+    pytorch_version = Version(torch.__version__)
+    pytorch_architecture = descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper)
+
+    # Package model
+    my_model_descr = package_to_bioimageio(
+        path_pretrained_model,
+        path_save_trace,
+        path_readme,
+        list_path_cover_images,
+        descr_input,
+        descr_output,
+        descr_output_downsampled_tensors,
+        descr_output_style_tensor,
+        pytorch_version,
+        pytorch_architecture,
+        args.model_id,
+        args.model_icon,
+        args.model_version,
+        args.model_name,
+        args.model_documentation,
+        model_authors,
+        model_cite,
+        args.model_tags,
+        args.model_license,
+        args.model_repo,
+    )
+
+    # Test model
+    summary = test_model(my_model_descr, weight_format="pytorch_state_dict")
+    summary.display()
+    summary = test_model(my_model_descr, weight_format="torchscript")
+    summary.display()
+
+    # Save BioImage.IO package
+    package_path = save_bioimageio_package(my_model_descr, output_path=Path(path_bioimageio_package))
+    print("package path:", package_path)
+
+
+if __name__ == "__main__":
+    main()
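
Editor's note: once `main()` has run, the traced TorchScript weights can be sanity-checked without the bioimageio tooling. A sketch under stated assumptions (the path below is hypothetical, the output unpacking assumes the tuple layout saved above, and the spatial sizes must respect the `ParameterizedSize(min=16, step=16)` input descriptors):

```python
import torch

traced = torch.jit.load("models/<model_stem>/cp_traced.pt")  # hypothetical path
dummy = torch.zeros(1, 2, 64, 64)   # (z, channel, y, x); y and x multiples of 16
flow, style, *downsampled = traced(dummy)
print(flow.shape, style.shape, len(downsampled))
```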
models/seg_post_model/cellpose/gui/gui.py ADDED
@@ -0,0 +1,2007 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+
5
+ import sys, os, pathlib, warnings, datetime, time, copy
6
+
7
+ from qtpy import QtGui, QtCore
8
+ from superqt import QRangeSlider, QCollapsible
9
+ from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, \
10
+ QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, \
11
+ QLineEdit, QMessageBox, QGroupBox, QMenu, QAction
12
+ import pyqtgraph as pg
13
+
14
+ import numpy as np
15
+ from scipy.stats import mode
16
+ import cv2
17
+
18
+ from . import guiparts, menus, io
19
+ from .. import models, core, dynamics, version, train
20
+ from ..utils import download_url_to_file, masks_to_outlines, diameters
21
+ from ..io import get_image_files, imsave, imread
22
+ from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img
23
+ from ..models import normalize_default
24
+ from ..plot import disk
25
+
26
+ try:
27
+ import matplotlib.pyplot as plt
28
+ MATPLOTLIB = True
29
+ except:
30
+ MATPLOTLIB = False
31
+
32
+ Horizontal = QtCore.Qt.Orientation.Horizontal
33
+
34
+
35
+ class Slider(QRangeSlider):
36
+
37
+ def __init__(self, parent, name, color):
38
+ super().__init__(Horizontal)
39
+ self.setEnabled(False)
40
+ self.valueChanged.connect(lambda: self.levelChanged(parent))
41
+ self.name = name
42
+
43
+ self.setStyleSheet(""" QSlider{
44
+ background-color: transparent;
45
+ }
46
+ """)
47
+ self.show()
48
+
49
+ def levelChanged(self, parent):
50
+ parent.level_change(self.name)
51
+
52
+
53
+ class QHLine(QFrame):
54
+
55
+ def __init__(self):
56
+ super(QHLine, self).__init__()
57
+ self.setFrameShape(QFrame.HLine)
58
+ self.setLineWidth(8)
59
+
60
+
61
+ def make_bwr():
62
+ # make a bwr colormap
63
+ b = np.append(255 * np.ones(128), np.linspace(0, 255, 128)[::-1])[:, np.newaxis]
64
+ r = np.append(np.linspace(0, 255, 128), 255 * np.ones(128))[:, np.newaxis]
65
+ g = np.append(np.linspace(0, 255, 128),
66
+ np.linspace(0, 255, 128)[::-1])[:, np.newaxis]
67
+ color = np.concatenate((r, g, b), axis=-1).astype(np.uint8)
68
+ bwr = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
69
+ return bwr
70
+
71
+
72
+ def make_spectral():
73
+ # make spectral colormap
74
+ r = np.array([
75
+ 0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80,
76
+ 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 128, 128, 128, 128, 128,
77
+ 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 120, 112, 104, 96, 88,
78
+ 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
79
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 7, 11, 15, 19, 23,
80
+ 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103,
81
+ 107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167,
82
+ 171, 175, 179, 183, 187, 191, 195, 199, 203, 207, 211, 215, 219, 223, 227, 231,
83
+ 235, 239, 243, 247, 251, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
84
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
85
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
86
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
87
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
88
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
89
+ 255, 255, 255, 255, 255
90
+ ])
91
+ g = np.array([
92
+ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 5, 4, 4, 3, 3,
93
+ 2, 2, 1, 1, 0, 0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111,
94
+ 119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239,
95
+ 247, 255, 247, 239, 231, 223, 215, 207, 199, 191, 183, 175, 167, 159, 151, 143,
96
+ 135, 128, 129, 131, 132, 134, 135, 137, 139, 140, 142, 143, 145, 147, 148, 150,
97
+ 151, 153, 154, 156, 158, 159, 161, 162, 164, 166, 167, 169, 170, 172, 174, 175,
98
+ 177, 178, 180, 181, 183, 185, 186, 188, 189, 191, 193, 194, 196, 197, 199, 201,
99
+ 202, 204, 205, 207, 208, 210, 212, 213, 215, 216, 218, 220, 221, 223, 224, 226,
100
+ 228, 229, 231, 232, 234, 235, 237, 239, 240, 242, 243, 245, 247, 248, 250, 251,
101
+ 253, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199,
102
+ 195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135,
103
+ 131, 127, 123, 119, 115, 111, 107, 103, 99, 95, 91, 87, 83, 79, 75, 71, 67, 63,
104
+ 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 0, 8, 16, 24, 32, 41,
105
+ 49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180,
106
+ 189, 197, 205, 213, 222, 230, 238, 246, 254
107
+ ])
108
+ b = np.array([
109
+ 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143,
110
+ 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 255, 255,
111
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
112
+ 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 251, 247,
113
+ 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183,
114
+ 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 128, 126, 124,
115
+ 122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98, 96, 94, 92, 90,
116
+ 88, 86, 84, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50,
117
+ 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10,
118
+ 8, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
119
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
120
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74,
121
+ 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205,
122
+ 213, 222, 230, 238, 246, 254
123
+ ])
124
+ color = (np.vstack((r, g, b)).T).astype(np.uint8)
125
+ spectral = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
126
+ return spectral
127
+
128
+
129
+ def make_cmap(cm=0):
130
+ # make a single channel colormap
131
+ r = np.arange(0, 256)
132
+ color = np.zeros((256, 3))
133
+ color[:, cm] = r
134
+ color = color.astype(np.uint8)
135
+ cmap = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
136
+ return cmap
137
+
138
+
139
+ def run(image=None):
140
+ from ..io import logger_setup
141
+ logger, log_file = logger_setup()
142
+ # Always start by initializing Qt (only once per application)
143
+ warnings.filterwarnings("ignore")
144
+ app = QApplication(sys.argv)
145
+ icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
146
+ guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png")
147
+ if not icon_path.is_file():
148
+ cp_dir = pathlib.Path.home().joinpath(".cellpose")
149
+ cp_dir.mkdir(exist_ok=True)
150
+ print("downloading logo")
151
+ download_url_to_file(
152
+ "https://www.cellpose.org/static/images/cellpose_transparent.png",
153
+ icon_path, progress=True)
154
+ if not guip_path.is_file():
155
+ print("downloading help window image")
156
+ download_url_to_file("https://www.cellpose.org/static/images/cellposeSAM_gui.png",
157
+ guip_path, progress=True)
158
+ icon_path = str(icon_path.resolve())
159
+ app_icon = QtGui.QIcon()
160
+ app_icon.addFile(icon_path, QtCore.QSize(16, 16))
161
+ app_icon.addFile(icon_path, QtCore.QSize(24, 24))
162
+ app_icon.addFile(icon_path, QtCore.QSize(32, 32))
163
+ app_icon.addFile(icon_path, QtCore.QSize(48, 48))
164
+ app_icon.addFile(icon_path, QtCore.QSize(64, 64))
165
+ app_icon.addFile(icon_path, QtCore.QSize(256, 256))
166
+ app.setWindowIcon(app_icon)
167
+ app.setStyle("Fusion")
168
+ app.setPalette(guiparts.DarkPalette())
169
+ MainW(image=image, logger=logger)
170
+ ret = app.exec_()
171
+ sys.exit(ret)
172
+
173
+
174
+ class MainW(QMainWindow):
175
+
176
+ def __init__(self, image=None, logger=None):
177
+ super(MainW, self).__init__()
178
+
179
+ self.logger = logger
180
+ pg.setConfigOptions(imageAxisOrder="row-major")
181
+ self.setGeometry(50, 50, 1200, 1000)
182
+ self.setWindowTitle(f"cellpose v{version}")
183
+ self.cp_path = os.path.dirname(os.path.realpath(__file__))
184
+ app_icon = QtGui.QIcon()
185
+ icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
186
+ icon_path = str(icon_path.resolve())
187
+ app_icon.addFile(icon_path, QtCore.QSize(16, 16))
188
+ app_icon.addFile(icon_path, QtCore.QSize(24, 24))
189
+ app_icon.addFile(icon_path, QtCore.QSize(32, 32))
190
+ app_icon.addFile(icon_path, QtCore.QSize(48, 48))
191
+ app_icon.addFile(icon_path, QtCore.QSize(64, 64))
192
+ app_icon.addFile(icon_path, QtCore.QSize(256, 256))
193
+ self.setWindowIcon(app_icon)
194
+ # rgb(150,255,150)
195
+ self.setStyleSheet(guiparts.stylesheet())
196
+
197
+ menus.mainmenu(self)
198
+ menus.editmenu(self)
199
+ menus.modelmenu(self)
200
+ menus.helpmenu(self)
201
+
202
+ self.stylePressed = """QPushButton {Text-align: center;
203
+ background-color: rgb(150,50,150);
204
+ border-color: white;
205
+ color:white;}
206
+ QToolTip {
207
+ background-color: black;
208
+ color: white;
209
+ border: black solid 1px
210
+ }"""
211
+ self.styleUnpressed = """QPushButton {Text-align: center;
212
+ background-color: rgb(50,50,50);
213
+ border-color: white;
214
+ color:white;}
215
+ QToolTip {
216
+ background-color: black;
217
+ color: white;
218
+ border: black solid 1px
219
+ }"""
220
+ self.loaded = False
221
+
222
+ # ---- MAIN WIDGET LAYOUT ---- #
223
+ self.cwidget = QWidget(self)
224
+ self.lmain = QGridLayout()
225
+ self.cwidget.setLayout(self.lmain)
226
+ self.setCentralWidget(self.cwidget)
227
+ self.lmain.setVerticalSpacing(0)
228
+ self.lmain.setContentsMargins(0, 0, 0, 10)
229
+
230
+ self.imask = 0
231
+ self.scrollarea = QScrollArea()
232
+ self.scrollarea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
233
+ self.scrollarea.setStyleSheet("""QScrollArea { border: none }""")
234
+ self.scrollarea.setWidgetResizable(True)
235
+ self.swidget = QWidget(self)
236
+ self.scrollarea.setWidget(self.swidget)
237
+ self.l0 = QGridLayout()
238
+ self.swidget.setLayout(self.l0)
239
+ b = self.make_buttons()
240
+ self.lmain.addWidget(self.scrollarea, 0, 0, 39, 9)
241
+
242
+ # ---- drawing area ---- #
243
+ self.win = pg.GraphicsLayoutWidget()
244
+
245
+ self.lmain.addWidget(self.win, 0, 9, 40, 30)
246
+
247
+ self.win.scene().sigMouseClicked.connect(self.plot_clicked)
248
+ self.win.scene().sigMouseMoved.connect(self.mouse_moved)
249
+ self.make_viewbox()
250
+ self.lmain.setColumnStretch(10, 1)
251
+ bwrmap = make_bwr()
252
+ self.bwr = bwrmap.getLookupTable(start=0.0, stop=255.0, alpha=False)
253
+ self.cmap = []
254
+ # spectral colormap
255
+ self.cmap.append(make_spectral().getLookupTable(start=0.0, stop=255.0,
256
+ alpha=False))
257
+ # single channel colormaps
258
+ for i in range(3):
259
+ self.cmap.append(
260
+ make_cmap(i).getLookupTable(start=0.0, stop=255.0, alpha=False))
261
+
262
+ if MATPLOTLIB:
263
+ self.colormap = (plt.get_cmap("gist_ncar")(np.linspace(0.0, .9, 1000000)) *
264
+ 255).astype(np.uint8)
265
+ np.random.seed(42) # make colors stable
266
+ self.colormap = self.colormap[np.random.permutation(1000000)]
267
+ else:
268
+ np.random.seed(42) # make colors stable
269
+ self.colormap = ((np.random.rand(1000000, 3) * 0.8 + 0.1) * 255).astype(
270
+ np.uint8)
271
+ self.NZ = 1
272
+ self.restore = None
273
+ self.ratio = 1.
274
+ self.reset()
275
+
276
+ # This needs to go after .reset() is called to get state fully set up:
277
+ self.autobtn.checkStateChanged.connect(self.compute_saturation_if_checked)
278
+
279
+ self.load_3D = False
280
+
281
+ # if called with image, load it
282
+ if image is not None:
283
+ self.filename = image
284
+ io._load_image(self, self.filename)
285
+
286
+ # training settings
287
+ d = datetime.datetime.now()
288
+ self.training_params = {
289
+ "model_index": 0,
290
+ "learning_rate": 1e-5,
291
+ "weight_decay": 0.1,
292
+ "n_epochs": 100,
293
+ "model_name": "cpsam" + d.strftime("_%Y%m%d_%H%M%S"),
294
+ }
295
+
296
+ self.stitch_threshold = 0.
297
+ self.flow3D_smooth = 0.
298
+ self.anisotropy = 1.
299
+ self.min_size = 15
300
+
301
+ self.setAcceptDrops(True)
302
+ self.win.show()
303
+ self.show()
304
+
305
+ def help_window(self):
306
+ HW = guiparts.HelpWindow(self)
307
+ HW.show()
308
+
309
+ def train_help_window(self):
310
+ THW = guiparts.TrainHelpWindow(self)
311
+ THW.show()
312
+
313
+ def gui_window(self):
314
+ EG = guiparts.ExampleGUI(self)
315
+ EG.show()
316
+
317
+ def make_buttons(self):
318
+ self.boldfont = QtGui.QFont("Arial", 11, QtGui.QFont.Bold)
319
+ self.boldmedfont = QtGui.QFont("Arial", 9, QtGui.QFont.Bold)
320
+ self.medfont = QtGui.QFont("Arial", 9)
321
+ self.smallfont = QtGui.QFont("Arial", 8)
322
+
323
+ b = 0
324
+ self.satBox = QGroupBox("Views")
325
+ self.satBox.setFont(self.boldfont)
326
+ self.satBoxG = QGridLayout()
327
+ self.satBox.setLayout(self.satBoxG)
328
+ self.l0.addWidget(self.satBox, b, 0, 1, 9)
329
+
330
+ widget_row = 0
331
+ self.view = 0 # 0=image, 1=flowsXY, 2=flowsZ, 3=cellprob
332
+ self.color = 0 # 0=RGB, 1=gray, 2=R, 3=G, 4=B
333
+ self.RGBDropDown = QComboBox()
334
+ self.RGBDropDown.addItems(
335
+ ["RGB", "red=R", "green=G", "blue=B", "gray", "spectral"])
336
+ self.RGBDropDown.setFont(self.medfont)
337
+ self.RGBDropDown.currentIndexChanged.connect(self.color_choose)
338
+ self.satBoxG.addWidget(self.RGBDropDown, widget_row, 0, 1, 3)
339
+
340
+ label = QLabel("<p>[&uarr; / &darr; or W/S]</p>")
341
+ label.setFont(self.smallfont)
342
+ self.satBoxG.addWidget(label, widget_row, 3, 1, 3)
343
+ label = QLabel("[R / G / B \n toggles color ]")
344
+ label.setFont(self.smallfont)
345
+ self.satBoxG.addWidget(label, widget_row, 6, 1, 3)
346
+
347
+ widget_row += 1
348
+ self.ViewDropDown = QComboBox()
349
+ self.ViewDropDown.addItems(["image", "gradXY", "cellprob", "restored"])
350
+ self.ViewDropDown.setFont(self.medfont)
351
+ self.ViewDropDown.model().item(3).setEnabled(False)
352
+ self.ViewDropDown.currentIndexChanged.connect(self.update_plot)
353
+ self.satBoxG.addWidget(self.ViewDropDown, widget_row, 0, 2, 3)
354
+
355
+ label = QLabel("[pageup / pagedown]")
356
+ label.setFont(self.smallfont)
357
+ self.satBoxG.addWidget(label, widget_row, 3, 1, 5)
358
+
359
+ widget_row += 2
360
+ label = QLabel("")
361
+ label.setToolTip(
362
+ "NOTE: manually changing the saturation bars does not affect normalization in segmentation"
363
+ )
364
+ self.satBoxG.addWidget(label, widget_row, 0, 1, 5)
365
+
366
+ self.autobtn = QCheckBox("auto-adjust saturation")
367
+ self.autobtn.setToolTip("sets scale-bars as normalized for segmentation")
368
+ self.autobtn.setFont(self.medfont)
369
+ self.autobtn.setChecked(True)
370
+ self.satBoxG.addWidget(self.autobtn, widget_row, 1, 1, 8)
371
+
372
+ widget_row += 1
373
+ self.sliders = []
374
+ colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [100, 100, 100]]
375
+ colornames = ["red", "Chartreuse", "DodgerBlue"]
376
+ names = ["red", "green", "blue"]
377
+ for r in range(3):
378
+ widget_row += 1
379
+ if r == 0:
380
+ label = QLabel('<font color="gray">gray/</font><br>red')
381
+ else:
382
+ label = QLabel(names[r] + ":")
383
+ label.setStyleSheet(f"color: {colornames[r]}")
384
+ label.setFont(self.boldmedfont)
385
+ self.satBoxG.addWidget(label, widget_row, 0, 1, 2)
386
+ self.sliders.append(Slider(self, names[r], colors[r]))
387
+ self.sliders[-1].setMinimum(-.1)
388
+ self.sliders[-1].setMaximum(255.1)
389
+ self.sliders[-1].setValue([0, 255])
390
+ self.sliders[-1].setToolTip(
391
+ "NOTE: manually changing the saturation bars does not affect normalization in segmentation"
392
+ )
393
+ self.satBoxG.addWidget(self.sliders[-1], widget_row, 2, 1, 7)
394
+
395
+ b += 1
396
+ self.drawBox = QGroupBox("Drawing")
397
+ self.drawBox.setFont(self.boldfont)
398
+ self.drawBoxG = QGridLayout()
399
+ self.drawBox.setLayout(self.drawBoxG)
400
+ self.l0.addWidget(self.drawBox, b, 0, 1, 9)
401
+ self.autosave = True
402
+
403
+ widget_row = 0
404
+ self.brush_size = 3
405
+ self.BrushChoose = QComboBox()
406
+ self.BrushChoose.addItems(["1", "3", "5", "7", "9"])
407
+ self.BrushChoose.currentIndexChanged.connect(self.brush_choose)
408
+ self.BrushChoose.setFixedWidth(40)
409
+ self.BrushChoose.setFont(self.medfont)
410
+ self.drawBoxG.addWidget(self.BrushChoose, widget_row, 3, 1, 2)
411
+ label = QLabel("brush size:")
412
+ label.setFont(self.medfont)
413
+ self.drawBoxG.addWidget(label, widget_row, 0, 1, 3)
414
+
415
+ widget_row += 1
416
+ # turn off masks
417
+ self.layer_off = False
418
+ self.masksOn = True
419
+ self.MCheckBox = QCheckBox("MASKS ON [X]")
420
+ self.MCheckBox.setFont(self.medfont)
421
+ self.MCheckBox.setChecked(True)
422
+ self.MCheckBox.toggled.connect(self.toggle_masks)
423
+ self.drawBoxG.addWidget(self.MCheckBox, widget_row, 0, 1, 5)
424
+
425
+ widget_row += 1
426
+ # turn off outlines
427
+ self.outlinesOn = False # turn off by default
428
+ self.OCheckBox = QCheckBox("outlines on [Z]")
429
+ self.OCheckBox.setFont(self.medfont)
430
+ self.drawBoxG.addWidget(self.OCheckBox, widget_row, 0, 1, 5)
431
+ self.OCheckBox.setChecked(False)
432
+ self.OCheckBox.toggled.connect(self.toggle_masks)
433
+
434
+ widget_row += 1
435
+ self.SCheckBox = QCheckBox("single stroke")
436
+ self.SCheckBox.setFont(self.medfont)
437
+ self.SCheckBox.setChecked(True)
438
+ self.SCheckBox.toggled.connect(self.autosave_on)
439
+ self.SCheckBox.setEnabled(True)
440
+ self.drawBoxG.addWidget(self.SCheckBox, widget_row, 0, 1, 5)
441
+
442
+ # buttons for deleting multiple cells
443
+ self.deleteBox = QGroupBox("delete multiple ROIs")
444
+ self.deleteBox.setStyleSheet("color: rgb(200, 200, 200)")
445
+ self.deleteBox.setFont(self.medfont)
446
+ self.deleteBoxG = QGridLayout()
447
+ self.deleteBox.setLayout(self.deleteBoxG)
448
+ self.drawBoxG.addWidget(self.deleteBox, 0, 5, 4, 4)
449
+ self.MakeDeletionRegionButton = QPushButton("region-select")
450
+ self.MakeDeletionRegionButton.clicked.connect(self.remove_region_cells)
451
+ self.deleteBoxG.addWidget(self.MakeDeletionRegionButton, 0, 0, 1, 4)
452
+ self.MakeDeletionRegionButton.setFont(self.smallfont)
453
+ self.MakeDeletionRegionButton.setFixedWidth(70)
454
+ self.DeleteMultipleROIButton = QPushButton("click-select")
455
+ self.DeleteMultipleROIButton.clicked.connect(self.delete_multiple_cells)
456
+ self.deleteBoxG.addWidget(self.DeleteMultipleROIButton, 1, 0, 1, 4)
457
+ self.DeleteMultipleROIButton.setFont(self.smallfont)
458
+ self.DeleteMultipleROIButton.setFixedWidth(70)
459
+ self.DoneDeleteMultipleROIButton = QPushButton("done")
460
+ self.DoneDeleteMultipleROIButton.clicked.connect(
461
+ self.done_remove_multiple_cells)
462
+ self.deleteBoxG.addWidget(self.DoneDeleteMultipleROIButton, 2, 0, 1, 2)
463
+ self.DoneDeleteMultipleROIButton.setFont(self.smallfont)
464
+ self.DoneDeleteMultipleROIButton.setFixedWidth(35)
465
+ self.CancelDeleteMultipleROIButton = QPushButton("cancel")
466
+ self.CancelDeleteMultipleROIButton.clicked.connect(self.cancel_remove_multiple)
467
+ self.deleteBoxG.addWidget(self.CancelDeleteMultipleROIButton, 2, 2, 1, 2)
468
+ self.CancelDeleteMultipleROIButton.setFont(self.smallfont)
469
+ self.CancelDeleteMultipleROIButton.setFixedWidth(35)
470
+
471
+ b += 1
472
+ widget_row = 0
473
+ self.segBox = QGroupBox("Segmentation")
474
+ self.segBoxG = QGridLayout()
475
+ self.segBox.setLayout(self.segBoxG)
476
+ self.l0.addWidget(self.segBox, b, 0, 1, 9)
477
+ self.segBox.setFont(self.boldfont)
478
+
479
+ widget_row += 1
480
+
481
+ # use GPU
482
+ self.useGPU = QCheckBox("use GPU")
483
+ self.useGPU.setToolTip(
484
+ "if you have specially installed the <i>cuda</i> version of torch, then you can activate this"
485
+ )
486
+ self.useGPU.setFont(self.medfont)
487
+ self.check_gpu()
488
+ self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3)
489
+
490
+ # compute segmentation with general models
491
+ self.net_text = ["run CPSAM"]
492
+ nett = ["cellpose super-generalist model"]
493
+
494
+ self.StyleButtons = []
495
+ jj = 4
496
+ for j in range(len(self.net_text)):
497
+ self.StyleButtons.append(
498
+ guiparts.ModelButton(self, self.net_text[j], self.net_text[j]))
499
+ w = 5
500
+ self.segBoxG.addWidget(self.StyleButtons[-1], widget_row, jj, 1, w)
501
+ jj += w
502
+ self.StyleButtons[-1].setToolTip(nett[j])
503
+
504
+ widget_row += 1
505
+ self.ncells = guiparts.ObservableVariable(0)
506
+ self.roi_count = QLabel()
507
+ self.roi_count.setFont(self.boldfont)
508
+ self.roi_count.setAlignment(QtCore.Qt.AlignLeft)
509
+ self.ncells.valueChanged.connect(
510
+ lambda n: self.roi_count.setText(f'{str(n)} ROIs')
511
+ )
512
+
513
+ self.segBoxG.addWidget(self.roi_count, widget_row, 0, 1, 4)
514
+
515
+ self.progress = QProgressBar(self)
516
+ self.segBoxG.addWidget(self.progress, widget_row, 4, 1, 5)
517
+
518
+ widget_row += 1
519
+
520
+ ############################### Segmentation settings ###############################
521
+ self.additional_seg_settings_qcollapsible = QCollapsible("additional settings")
522
+ self.additional_seg_settings_qcollapsible.setFont(self.medfont)
523
+ self.additional_seg_settings_qcollapsible._toggle_btn.setFont(self.medfont)
524
+ self.segmentation_settings = guiparts.SegmentationSettings(self.medfont)
525
+ self.additional_seg_settings_qcollapsible.setContent(self.segmentation_settings)
526
+ self.segBoxG.addWidget(self.additional_seg_settings_qcollapsible, widget_row, 0, 1, 9)
527
+
528
+ # connect edits to image processing steps:
529
+ self.segmentation_settings.diameter_box.editingFinished.connect(self.update_scale)
530
+ self.segmentation_settings.flow_threshold_box.returnPressed.connect(self.compute_cprob)
531
+ self.segmentation_settings.cellprob_threshold_box.returnPressed.connect(self.compute_cprob)
532
+ self.segmentation_settings.niter_box.returnPressed.connect(self.compute_cprob)
533
+
534
+ # Needed to do this for the drop down to not be open on startup
535
+ self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(True)
536
+ self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(False)
537
+
538
+ b += 1
539
+ self.modelBox = QGroupBox("user-trained models")
540
+ self.modelBoxG = QGridLayout()
541
+ self.modelBox.setLayout(self.modelBoxG)
542
+ self.l0.addWidget(self.modelBox, b, 0, 1, 9)
543
+ self.modelBox.setFont(self.boldfont)
544
+ # choose models
545
+ self.ModelChooseC = QComboBox()
546
+ self.ModelChooseC.setFont(self.medfont)
547
+ current_index = 0
548
+ self.ModelChooseC.addItems(["custom models"])
549
+ if len(self.model_strings) > 0:
550
+ self.ModelChooseC.addItems(self.model_strings)
551
+ self.ModelChooseC.setFixedWidth(175)
552
+ self.ModelChooseC.setCurrentIndex(current_index)
553
+ tipstr = 'add or train your own models in the "Models" file menu and choose model here'
554
+ self.ModelChooseC.setToolTip(tipstr)
555
+ self.ModelChooseC.activated.connect(lambda: self.model_choose(custom=True))
556
+ self.modelBoxG.addWidget(self.ModelChooseC, widget_row, 0, 1, 8)
557
+
558
+ # compute segmentation w/ custom model
559
+ self.ModelButtonC = QPushButton(u"run")
560
+ self.ModelButtonC.setFont(self.medfont)
561
+ self.ModelButtonC.setFixedWidth(35)
562
+ self.ModelButtonC.clicked.connect(
563
+ lambda: self.compute_segmentation(custom=True))
564
+ self.modelBoxG.addWidget(self.ModelButtonC, widget_row, 8, 1, 1)
565
+ self.ModelButtonC.setEnabled(False)
566
+
567
+
568
+ b += 1
569
+ self.filterBox = QGroupBox("Image filtering")
570
+ self.filterBox.setFont(self.boldfont)
571
+ self.filterBox_grid_layout = QGridLayout()
572
+ self.filterBox.setLayout(self.filterBox_grid_layout)
573
+ self.l0.addWidget(self.filterBox, b, 0, 1, 9)
574
+
575
+ widget_row = 0
576
+
577
+ # Filtering
578
+ self.FilterButtons = []
579
+ nett = [
580
+ "clear restore/filter",
581
+ "filter image (settings below)",
582
+ ]
583
+ self.filter_text = ["none",
584
+ "filter",
585
+ ]
586
+ self.restore = None
587
+ self.ratio = 1.
588
+ jj = 0
589
+ w = 3
590
+ for j in range(len(self.filter_text)):
591
+ self.FilterButtons.append(
592
+ guiparts.FilterButton(self, self.filter_text[j]))
593
+ self.filterBox_grid_layout.addWidget(self.FilterButtons[-1], widget_row, jj, 1, w)
594
+ self.FilterButtons[-1].setFixedWidth(75)
595
+ self.FilterButtons[-1].setToolTip(nett[j])
596
+ self.FilterButtons[-1].setFont(self.medfont)
597
+ widget_row += 1 if j%2==1 else 0
598
+ jj = 0 if j%2==1 else jj + w
599
+
600
+ self.save_norm = QCheckBox("save restored/filtered image")
601
+ self.save_norm.setFont(self.medfont)
602
+ self.save_norm.setToolTip("save restored/filtered image in _seg.npy file")
603
+ self.save_norm.setChecked(True)
604
+
605
+ widget_row += 2
606
+
607
+ self.filtBox = QCollapsible("custom filter settings")
608
+ self.filtBox._toggle_btn.setFont(self.medfont)
609
+ self.filtBoxG = QGridLayout()
610
+ _content = QWidget()
611
+ _content.setLayout(self.filtBoxG)
612
+ _content.setMaximumHeight(0)
613
+ _content.setMinimumHeight(0)
614
+ self.filtBox.setContent(_content)
615
+ self.filterBox_grid_layout.addWidget(self.filtBox, widget_row, 0, 1, 9)
616
+
617
+ self.filt_vals = [0., 0., 0., 0.]
618
+ self.filt_edits = []
619
+ labels = [
620
+ "sharpen\nradius", "smooth\nradius", "tile_norm\nblocksize",
621
+ "tile_norm\nsmooth3D"
622
+ ]
623
+ tooltips = [
624
+ "set size of surround-subtraction filter for sharpening image",
625
+ "set size of gaussian filter for smoothing image",
626
+ "set size of tiles to use to normalize image",
627
+ "set amount of smoothing of normalization values across planes"
628
+ ]
629
+
630
+ for p in range(4):
631
+ label = QLabel(f"{labels[p]}:")
632
+ label.setToolTip(tooltips[p])
633
+ label.setFont(self.medfont)
634
+ self.filtBoxG.addWidget(label, widget_row + p // 2, 4 * (p % 2), 1, 2)
635
+ self.filt_edits.append(QLineEdit())
636
+ self.filt_edits[p].setText(str(self.filt_vals[p]))
637
+ self.filt_edits[p].setFixedWidth(40)
638
+ self.filt_edits[p].setFont(self.medfont)
639
+ self.filtBoxG.addWidget(self.filt_edits[p], widget_row + p // 2, 4 * (p % 2) + 2, 1,
640
+ 2)
641
+ self.filt_edits[p].setToolTip(tooltips[p])
642
+
643
+ widget_row += 3
644
+ self.norm3D_cb = QCheckBox("norm3D")
645
+ self.norm3D_cb.setFont(self.medfont)
646
+ self.norm3D_cb.setChecked(True)
647
+ self.norm3D_cb.setToolTip("run same normalization across planes")
648
+ self.filtBoxG.addWidget(self.norm3D_cb, widget_row, 0, 1, 3)
649
+
650
+
651
+ return b
652
+
653
+ def level_change(self, r):
654
+ r = ["red", "green", "blue"].index(r)
655
+ if self.loaded:
656
+ sval = self.sliders[r].value()
657
+ self.saturation[r][self.currentZ] = sval
658
+ if not self.autobtn.isChecked():
659
+ for r in range(3):
660
+ for i in range(len(self.saturation[r])):
661
+ self.saturation[r][i] = self.saturation[r][self.currentZ]
662
+ self.update_plot()
663
+
664
+ def keyPressEvent(self, event):
665
+ if self.loaded:
666
+ if not (event.modifiers() &
667
+ (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
668
+ QtCore.Qt.AltModifier) or self.in_stroke):
669
+ updated = False
670
+ if len(self.current_point_set) > 0:
671
+ if event.key() == QtCore.Qt.Key_Return:
672
+ self.add_set()
673
+ else:
674
+ nviews = self.ViewDropDown.count() - 1
675
+ nviews += int(
676
+ self.ViewDropDown.model().item(self.ViewDropDown.count() -
677
+ 1).isEnabled())
678
+ if event.key() == QtCore.Qt.Key_X:
679
+ self.MCheckBox.toggle()
680
+ if event.key() == QtCore.Qt.Key_Z:
681
+ self.OCheckBox.toggle()
682
+ if event.key() == QtCore.Qt.Key_Left or event.key(
683
+ ) == QtCore.Qt.Key_A:
684
+ self.get_prev_image()
685
+ elif event.key() == QtCore.Qt.Key_Right or event.key(
686
+ ) == QtCore.Qt.Key_D:
687
+ self.get_next_image()
688
+ elif event.key() == QtCore.Qt.Key_PageDown:
689
+ self.view = (self.view + 1) % (nviews)
690
+ self.ViewDropDown.setCurrentIndex(self.view)
691
+ elif event.key() == QtCore.Qt.Key_PageUp:
692
+ self.view = (self.view - 1) % (nviews)
693
+ self.ViewDropDown.setCurrentIndex(self.view)
694
+
695
+ # can change background or stroke size if cell not finished
696
+ if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
697
+ self.color = (self.color - 1) % (6)
698
+ self.RGBDropDown.setCurrentIndex(self.color)
699
+ elif event.key() == QtCore.Qt.Key_Down or event.key(
700
+ ) == QtCore.Qt.Key_S:
701
+ self.color = (self.color + 1) % (6)
702
+ self.RGBDropDown.setCurrentIndex(self.color)
703
+ elif event.key() == QtCore.Qt.Key_R:
704
+ if self.color != 1:
705
+ self.color = 1
706
+ else:
707
+ self.color = 0
708
+ self.RGBDropDown.setCurrentIndex(self.color)
709
+ elif event.key() == QtCore.Qt.Key_G:
710
+ if self.color != 2:
711
+ self.color = 2
712
+ else:
713
+ self.color = 0
714
+ self.RGBDropDown.setCurrentIndex(self.color)
715
+ elif event.key() == QtCore.Qt.Key_B:
716
+ if self.color != 3:
717
+ self.color = 3
718
+ else:
719
+ self.color = 0
720
+ self.RGBDropDown.setCurrentIndex(self.color)
721
+ elif (event.key() == QtCore.Qt.Key_Comma or
722
+ event.key() == QtCore.Qt.Key_Period):
723
+ count = self.BrushChoose.count()
724
+ gci = self.BrushChoose.currentIndex()
725
+ if event.key() == QtCore.Qt.Key_Comma:
726
+ gci = max(0, gci - 1)
727
+ else:
728
+ gci = min(count - 1, gci + 1)
729
+ self.BrushChoose.setCurrentIndex(gci)
730
+ self.brush_choose()
731
+ if not updated:
732
+ self.update_plot()
733
+ if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
734
+ self.p0.keyPressEvent(event)
735
+
736
+ def autosave_on(self):
737
+ if self.SCheckBox.isChecked():
738
+ self.autosave = True
739
+ else:
740
+ self.autosave = False
741
+
742
+ def check_gpu(self, torch=True):
743
+ # also decide whether or not to use torch
744
+ self.useGPU.setChecked(False)
745
+ self.useGPU.setEnabled(False)
746
+ if core.use_gpu(use_torch=True):
747
+ self.useGPU.setEnabled(True)
748
+ self.useGPU.setChecked(True)
749
+ else:
750
+ self.useGPU.setStyleSheet("color: rgb(80,80,80);")
751
+
752
+
753
+ def model_choose(self, custom=False):
754
+ index = self.ModelChooseC.currentIndex(
755
+ ) if custom else self.ModelChooseB.currentIndex()
756
+ if index > 0:
757
+ if custom:
758
+ model_name = self.ModelChooseC.currentText()
759
+ else:
760
+ model_name = self.net_names[index - 1]
761
+ print(f"GUI_INFO: selected model {model_name}, loading now")
762
+ self.initialize_model(model_name=model_name, custom=custom)
763
+
764
+ def toggle_scale(self):
765
+ if self.scale_on:
766
+ self.p0.removeItem(self.scale)
767
+ self.scale_on = False
768
+ else:
769
+ self.p0.addItem(self.scale)
770
+ self.scale_on = True
771
+
772
+ def enable_buttons(self):
773
+ if len(self.model_strings) > 0:
774
+ self.ModelButtonC.setEnabled(True)
775
+ for i in range(len(self.StyleButtons)):
776
+ self.StyleButtons[i].setEnabled(True)
777
+
778
+ for i in range(len(self.FilterButtons)):
779
+ self.FilterButtons[i].setEnabled(True)
780
+ if self.load_3D:
781
+ self.FilterButtons[-2].setEnabled(False)
782
+
783
+ self.newmodel.setEnabled(True)
784
+ self.loadMasks.setEnabled(True)
785
+
786
+ for n in range(self.nchan):
787
+ self.sliders[n].setEnabled(True)
788
+ for n in range(self.nchan, 3):
789
+ self.sliders[n].setEnabled(True)
790
+
791
+ self.toggle_mask_ops()
792
+
793
+ self.update_plot()
794
+ self.setWindowTitle(self.filename)
795
+
796
+ def disable_buttons_removeROIs(self):
797
+ if len(self.model_strings) > 0:
798
+ self.ModelButtonC.setEnabled(False)
799
+ for i in range(len(self.StyleButtons)):
800
+ self.StyleButtons[i].setEnabled(False)
801
+ self.newmodel.setEnabled(False)
802
+ self.loadMasks.setEnabled(False)
803
+ self.saveSet.setEnabled(False)
804
+ self.savePNG.setEnabled(False)
805
+ self.saveFlows.setEnabled(False)
806
+ self.saveOutlines.setEnabled(False)
807
+ self.saveROIs.setEnabled(False)
808
+
809
+ self.MakeDeletionRegionButton.setEnabled(False)
810
+ self.DeleteMultipleROIButton.setEnabled(False)
811
+ self.DoneDeleteMultipleROIButton.setEnabled(True)
812
+ self.CancelDeleteMultipleROIButton.setEnabled(True)
813
+
814
+ def toggle_mask_ops(self):
815
+ self.update_layer()
816
+ self.toggle_saving()
817
+ self.toggle_removals()
818
+
819
+ def toggle_saving(self):
820
+ if self.ncells > 0:
821
+ self.saveSet.setEnabled(True)
822
+ self.savePNG.setEnabled(True)
823
+ self.saveFlows.setEnabled(True)
824
+ self.saveOutlines.setEnabled(True)
825
+ self.saveROIs.setEnabled(True)
826
+ else:
827
+ self.saveSet.setEnabled(False)
828
+ self.savePNG.setEnabled(False)
829
+ self.saveFlows.setEnabled(False)
830
+ self.saveOutlines.setEnabled(False)
831
+ self.saveROIs.setEnabled(False)
832
+
833
+ def toggle_removals(self):
834
+ if self.ncells > 0:
835
+ self.ClearButton.setEnabled(True)
836
+ self.remcell.setEnabled(True)
837
+ self.undo.setEnabled(True)
838
+ self.MakeDeletionRegionButton.setEnabled(True)
839
+ self.DeleteMultipleROIButton.setEnabled(True)
840
+ self.DoneDeleteMultipleROIButton.setEnabled(False)
841
+ self.CancelDeleteMultipleROIButton.setEnabled(False)
842
+ else:
843
+ self.ClearButton.setEnabled(False)
844
+ self.remcell.setEnabled(False)
845
+ self.undo.setEnabled(False)
846
+ self.MakeDeletionRegionButton.setEnabled(False)
847
+ self.DeleteMultipleROIButton.setEnabled(False)
848
+ self.DoneDeleteMultipleROIButton.setEnabled(False)
849
+ self.CancelDeleteMultipleROIButton.setEnabled(False)
850
+
851
+ def remove_action(self):
852
+ if self.selected > 0:
853
+ self.remove_cell(self.selected)
854
+
855
+ def undo_action(self):
856
+ if (len(self.strokes) > 0 and self.strokes[-1][0][0] == self.currentZ):
857
+ self.remove_stroke()
858
+ else:
859
+ # remove previous cell
860
+ if self.ncells > 0:
861
+ self.remove_cell(self.ncells.get())
862
+
863
+ def undo_remove_action(self):
864
+ self.undo_remove_cell()
865
+
866
+ def get_files(self):
867
+ folder = os.path.dirname(self.filename)
868
+ mask_filter = "_masks"
869
+ images = get_image_files(folder, mask_filter)
870
+ fnames = [os.path.split(images[k])[-1] for k in range(len(images))]
871
+ f0 = os.path.split(self.filename)[-1]
872
+ idx = np.nonzero(np.array(fnames) == f0)[0][0]
873
+ return images, idx
874
+
875
+ def get_prev_image(self):
876
+ images, idx = self.get_files()
877
+ idx = (idx - 1) % len(images)
878
+ io._load_image(self, filename=images[idx])
879
+
880
+ def get_next_image(self, load_seg=True):
881
+ images, idx = self.get_files()
882
+ idx = (idx + 1) % len(images)
883
+ io._load_image(self, filename=images[idx], load_seg=load_seg)
884
+
885
+ def dragEnterEvent(self, event):
886
+ if event.mimeData().hasUrls():
887
+ event.accept()
888
+ else:
889
+ event.ignore()
890
+
891
+ def dropEvent(self, event):
892
+ files = [u.toLocalFile() for u in event.mimeData().urls()]
893
+ if os.path.splitext(files[0])[-1] == ".npy":
894
+ io._load_seg(self, filename=files[0], load_3D=self.load_3D)
895
+ else:
896
+ io._load_image(self, filename=files[0], load_seg=True, load_3D=self.load_3D)
897
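+ # dropped files are routed by extension: a .npy file (the GUI's *_seg.npy
+ # output) is reloaded as a saved segmentation, anything else is opened as
+ # an image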
+
898
+ def toggle_masks(self):
899
+ if self.MCheckBox.isChecked():
900
+ self.masksOn = True
901
+ else:
902
+ self.masksOn = False
903
+ if self.OCheckBox.isChecked():
904
+ self.outlinesOn = True
905
+ else:
906
+ self.outlinesOn = False
907
+ if not self.masksOn and not self.outlinesOn:
908
+ self.p0.removeItem(self.layer)
909
+ self.layer_off = True
910
+ else:
911
+ if self.layer_off:
912
+ self.p0.addItem(self.layer)
913
+ self.draw_layer()
914
+ self.update_layer()
915
+ if self.loaded:
916
+ self.update_plot()
917
+ self.update_layer()
918
+
919
+ def make_viewbox(self):
920
+ self.p0 = guiparts.ViewBoxNoRightDrag(parent=self, lockAspect=True,
921
+ name="plot1", border=[100, 100,
922
+ 100], invertY=True)
923
+ self.p0.setCursor(QtCore.Qt.CrossCursor)
924
+ self.brush_size = 3
925
+ self.win.addItem(self.p0, 0, 0, rowspan=1, colspan=1)
926
+ self.p0.setMenuEnabled(False)
927
+ self.p0.setMouseEnabled(x=True, y=True)
928
+ self.img = pg.ImageItem(viewbox=self.p0, parent=self)
929
+ self.img.autoDownsample = False
930
+ self.layer = guiparts.ImageDraw(viewbox=self.p0, parent=self)
931
+ self.layer.setLevels([0, 255])
932
+ self.scale = pg.ImageItem(viewbox=self.p0, parent=self)
933
+ self.scale.setLevels([0, 255])
934
+ self.p0.scene().contextMenuItem = self.p0
935
+ self.Ly, self.Lx = 512, 512
936
+ self.p0.addItem(self.img)
937
+ self.p0.addItem(self.layer)
938
+ self.p0.addItem(self.scale)
939
+
940
+ def reset(self):
941
+ # ---- start sets of points ---- #
942
+ self.selected = 0
943
+ self.nchan = 3
944
+ self.loaded = False
945
+ self.channel = [0, 1]
946
+ self.current_point_set = []
947
+ self.in_stroke = False
948
+ self.strokes = []
949
+ self.stroke_appended = True
950
+ self.resize = False
951
+ self.ncells.reset()
952
+ self.zdraw = []
953
+ self.removed_cell = []
954
+ self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
955
+
956
+ # -- zero out image stack -- #
957
+ self.opacity = 128 # how opaque masks should be
958
+ self.outcolor = [200, 200, 255, 200]
959
+ self.NZ, self.Ly, self.Lx = 1, 256, 256
960
+ self.saturation = self.saturation if hasattr(self, 'saturation') else []
961
+
962
+ # only adjust the saturation if auto-adjust is on:
963
+ if self.autobtn.isChecked():
964
+ for r in range(3):
965
+ self.saturation.append([[0, 255] for n in range(self.NZ)])
966
+ self.sliders[r].setValue([0, 255])
967
+ self.sliders[r].setEnabled(False)
968
+ self.sliders[r].show()
969
+ self.currentZ = 0
970
+ self.flows = [[], [], [], [], [[]]]
971
+ # masks matrix
972
+ # image matrix with a scale disk
973
+ self.stack = np.zeros((1, self.Ly, self.Lx, 3))
974
+ self.Lyr, self.Lxr = self.Ly, self.Lx
975
+ self.Ly0, self.Lx0 = self.Ly, self.Lx
976
+ self.radii = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
977
+ self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
978
+ self.cellpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
979
+ self.outpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
980
+ self.ismanual = np.zeros(0, "bool")
981
+
982
+ # -- set menus to default -- #
983
+ self.color = 0
984
+ self.RGBDropDown.setCurrentIndex(self.color)
985
+ self.view = 0
986
+ self.ViewDropDown.setCurrentIndex(0)
987
+ self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
988
+ self.delete_restore()
989
+
990
+ self.clear_all()
991
+
992
+ self.filename = []
993
+ self.loaded = False
994
+ self.recompute_masks = False
995
+
996
+ self.deleting_multiple = False
997
+ self.removing_cells_list = []
998
+ self.removing_region = False
999
+ self.remove_roi_obj = None
1000
+
1001
+ def delete_restore(self):
1002
+ """ delete restored imgs but don't reset settings """
1003
+ if hasattr(self, "stack_filtered"):
1004
+ del self.stack_filtered
1005
+ if hasattr(self, "cellpix_orig"):
1006
+ self.cellpix = self.cellpix_orig.copy()
1007
+ self.outpix = self.outpix_orig.copy()
1008
+ del self.outpix_orig, self.outpix_resize
1009
+ del self.cellpix_orig, self.cellpix_resize
1010
+
1011
+ def clear_restore(self):
1012
+ """ delete restored imgs and reset settings """
1013
+ print("GUI_INFO: clearing restored image")
1014
+ self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
1015
+ if self.ViewDropDown.currentIndex() == self.ViewDropDown.count() - 1:
1016
+ self.ViewDropDown.setCurrentIndex(0)
1017
+ self.delete_restore()
1018
+ self.restore = None
1019
+ self.ratio = 1.
1020
+ self.set_normalize_params(self.get_normalize_params())
1021
+
1022
+ def brush_choose(self):
1023
+ self.brush_size = self.BrushChoose.currentIndex() * 2 + 1
1024
+ if self.loaded:
1025
+ self.layer.setDrawKernel(kernel_size=self.brush_size)
1026
+ self.update_layer()
1027
+
1028
+ def clear_all(self):
1029
+ self.prev_selected = 0
1030
+ self.selected = 0
1031
+ if self.restore and "upsample" in self.restore:
1032
+ self.layerz = 0 * np.ones((self.Lyr, self.Lxr, 4), np.uint8)
1033
+ self.cellpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16)
1034
+ self.outpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16)
1035
+ self.cellpix_resize = self.cellpix.copy()
1036
+ self.outpix_resize = self.outpix.copy()
1037
+ self.cellpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16)
1038
+ self.outpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16)
1039
+ else:
1040
+ self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
1041
+ self.cellpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16)
1042
+ self.outpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16)
1043
+
1044
+ self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
1045
+ self.ncells.reset()
1046
+ self.toggle_removals()
1047
+ self.update_scale()
1048
+ self.update_layer()
1049
+
1050
+ def select_cell(self, idx):
1051
+ self.prev_selected = self.selected
1052
+ self.selected = idx
1053
+ if self.selected > 0:
1054
+ z = self.currentZ
1055
+ self.layerz[self.cellpix[z] == idx] = np.array(
1056
+ [255, 255, 255, self.opacity])
1057
+ self.update_layer()
1058
+
1059
+ def select_cell_multi(self, idx):
1060
+ if idx > 0:
1061
+ z = self.currentZ
1062
+ self.layerz[self.cellpix[z] == idx] = np.array(
1063
+ [255, 255, 255, self.opacity])
1064
+ self.update_layer()
1065
+
1066
+ def unselect_cell(self):
1067
+ if self.selected > 0:
1068
+ idx = self.selected
1069
+ if idx < (self.ncells.get() + 1):
1070
+ z = self.currentZ
1071
+ self.layerz[self.cellpix[z] == idx] = np.append(
1072
+ self.cellcolors[idx], self.opacity)
1073
+ if self.outlinesOn:
1074
+ self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype(
1075
+ np.uint8)
1076
+ #[0,0,0,self.opacity])
1077
+ self.update_layer()
1078
+ self.selected = 0
1079
+
1080
+ def unselect_cell_multi(self, idx):
1081
+ z = self.currentZ
1082
+ self.layerz[self.cellpix[z] == idx] = np.append(self.cellcolors[idx],
1083
+ self.opacity)
1084
+ if self.outlinesOn:
1085
+ self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype(
1086
+ np.uint8)
1087
+ # [0,0,0,self.opacity])
1088
+ self.update_layer()
1089
+
1090
+ def remove_cell(self, idx):
1091
+ if isinstance(idx, (int, np.integer)):
1092
+ idx = [idx]
1093
+ # remove_single_cell reindexes the remaining cells to keep the labels
1094
+ # contiguous, so a list of cells must be removed in descending order for
1095
+ # each index to still be valid when its turn comes
1096
+ idx.sort(reverse=True)
1097
+ for i in idx:
1098
+ self.remove_single_cell(i)
1099
+ self.ncells -= len(idx) # _save_sets uses ncells
1100
+ self.update_layer()
1101
+
1102
+ if self.ncells == 0:
1103
+ self.ClearButton.setEnabled(False)
1104
+ if self.NZ == 1:
1105
+ io._save_sets_with_check(self)
1106
+
1107
+
1108
+ def remove_single_cell(self, idx):
1109
+ # remove from manual array
1110
+ self.selected = 0
1111
+ if self.NZ > 1:
1112
+ zextent = ((self.cellpix == idx).sum(axis=(1, 2)) > 0).nonzero()[0]
1113
+ else:
1114
+ zextent = [0]
1115
+ for z in zextent:
1116
+ cp = self.cellpix[z] == idx
1117
+ op = self.outpix[z] == idx
1118
+ # remove from self.cellpix and self.outpix
1119
+ self.cellpix[z, cp] = 0
1120
+ self.outpix[z, op] = 0
1121
+ if z == self.currentZ:
1122
+ # remove from mask layer
1123
+ self.layerz[cp] = np.array([0, 0, 0, 0])
1124
+
1125
+ # reduce other pixels by -1
1126
+ self.cellpix[self.cellpix > idx] -= 1
1127
+ self.outpix[self.outpix > idx] -= 1
1128
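+ # illustrative example (not in the original source): with masks labeled
+ # 1..4, removing idx=2 shifts label 3->2 and 4->3, keeping labels
+ # contiguous -- which is why remove_cell deletes in descending order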
+
1129
+ if self.NZ == 1:
1130
+ self.removed_cell = [
1131
+ self.ismanual[idx - 1], self.cellcolors[idx],
1132
+ np.nonzero(cp),
1133
+ np.nonzero(op)
1134
+ ]
1135
+ self.redo.setEnabled(True)
1136
+ ar, ac = self.removed_cell[2]
1137
+ d = datetime.datetime.now()
1138
+ self.track_changes.append(
1139
+ [d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]])
1140
+ # remove cell from lists
1141
+ self.ismanual = np.delete(self.ismanual, idx - 1)
1142
+ self.cellcolors = np.delete(self.cellcolors, [idx], axis=0)
1143
+ del self.zdraw[idx - 1]
1144
+ print("GUI_INFO: removed cell %d" % (idx - 1))
1145
+
1146
+ def remove_region_cells(self):
1147
+ if self.removing_cells_list:
1148
+ for idx in self.removing_cells_list:
1149
+ self.unselect_cell_multi(idx)
1150
+ self.removing_cells_list.clear()
1151
+ self.disable_buttons_removeROIs()
1152
+ self.removing_region = True
1153
+
1154
+ self.clear_multi_selected_cells()
1155
+
1156
+ # make roi region here in center of view, making ROI half the size of the view
1157
+ roi_width = self.p0.viewRect().width() / 2
1158
+ x_loc = self.p0.viewRect().x() + (roi_width / 2)
1159
+ roi_height = self.p0.viewRect().height() / 2
1160
+ y_loc = self.p0.viewRect().y() + (roi_height / 2)
1161
+
1162
+ pos = [x_loc, y_loc]
1163
+ roi = pg.RectROI(pos, [roi_width, roi_height], pen=pg.mkPen("y", width=2),
1164
+ removable=True)
1165
+ roi.sigRemoveRequested.connect(self.remove_roi)
1166
+ roi.sigRegionChangeFinished.connect(self.roi_changed)
1167
+ self.p0.addItem(roi)
1168
+ self.remove_roi_obj = roi
1169
+ self.roi_changed(roi)
1170
+
1171
+ def delete_multiple_cells(self):
1172
+ self.unselect_cell()
1173
+ self.disable_buttons_removeROIs()
1174
+ self.DoneDeleteMultipleROIButton.setEnabled(True)
1175
+ self.MakeDeletionRegionButton.setEnabled(True)
1176
+ self.CancelDeleteMultipleROIButton.setEnabled(True)
1177
+ self.deleting_multiple = True
1178
+
1179
+ def done_remove_multiple_cells(self):
1180
+ self.deleting_multiple = False
1181
+ self.removing_region = False
1182
+ self.DoneDeleteMultipleROIButton.setEnabled(False)
1183
+ self.MakeDeletionRegionButton.setEnabled(False)
1184
+ self.CancelDeleteMultipleROIButton.setEnabled(False)
1185
+
1186
+ if self.removing_cells_list:
1187
+ self.removing_cells_list = list(set(self.removing_cells_list))
1188
+ display_remove_list = [i - 1 for i in self.removing_cells_list]
1189
+ print(f"GUI_INFO: removing cells: {display_remove_list}")
1190
+ self.remove_cell(self.removing_cells_list)
1191
+ self.removing_cells_list.clear()
1192
+ self.unselect_cell()
1193
+ self.enable_buttons()
1194
+
1195
+ if self.remove_roi_obj is not None:
1196
+ self.remove_roi(self.remove_roi_obj)
1197
+
1198
+ def merge_cells(self, idx):
1199
+ self.prev_selected = self.selected
1200
+ self.selected = idx
1201
+ if self.selected != self.prev_selected:
1202
+ for z in range(self.NZ):
1203
+ ar0, ac0 = np.nonzero(self.cellpix[z] == self.prev_selected)
1204
+ ar1, ac1 = np.nonzero(self.cellpix[z] == self.selected)
1205
+ touching = np.logical_and(np.abs(ar0[:, np.newaxis] - ar1) < 3,
1206
+ np.abs(ac0[:, np.newaxis] - ac1) < 3).sum()
1207
+ ar = np.hstack((ar0, ar1))
1208
+ ac = np.hstack((ac0, ac1))
1209
+ vr0, vc0 = np.nonzero(self.outpix[z] == self.prev_selected)
1210
+ vr1, vc1 = np.nonzero(self.outpix[z] == self.selected)
1211
+ self.outpix[z, vr0, vc0] = 0
1212
+ self.outpix[z, vr1, vc1] = 0
1213
+ if touching > 0:
1214
+ mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8)
1215
+ mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
1216
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
1217
+ cv2.CHAIN_APPROX_NONE)
1218
+ pvc, pvr = contours[-2][0].squeeze().T
1219
+ vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
1220
+
1221
+ else:
1222
+ vr = np.hstack((vr0, vr1))
1223
+ vc = np.hstack((vc0, vc1))
1224
+ color = self.cellcolors[self.prev_selected]
1225
+ self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected)
1226
+ self.remove_cell(self.selected)
1227
+ print("GUI_INFO: merged two cells")
1228
+ self.update_layer()
1229
+ io._save_sets_with_check(self)
1230
+ self.undo.setEnabled(False)
1231
+ self.redo.setEnabled(False)
1232
+
1233
+ def undo_remove_cell(self):
1234
+ if len(self.removed_cell) > 0:
1235
+ z = 0
1236
+ ar, ac = self.removed_cell[2]
1237
+ vr, vc = self.removed_cell[3]
1238
+ color = self.removed_cell[1]
1239
+ self.draw_mask(z, ar, ac, vr, vc, color)
1240
+ self.toggle_mask_ops()
1241
+ self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], axis=0)
1242
+ self.ncells += 1
1243
+ self.ismanual = np.append(self.ismanual, self.removed_cell[0])
1244
+ self.zdraw.append([])
1245
+ print(">>> added back removed cell")
1246
+ self.update_layer()
1247
+ io._save_sets_with_check(self)
1248
+ self.removed_cell = []
1249
+ self.redo.setEnabled(False)
1250
+
1251
+ def remove_stroke(self, delete_points=True, stroke_ind=-1):
1252
+ stroke = np.array(self.strokes[stroke_ind])
1253
+ cZ = self.currentZ
1254
+ inZ = stroke[0, 0] == cZ
1255
+ if inZ:
1256
+ outpix = self.outpix[cZ, stroke[:, 1], stroke[:, 2]] > 0
1257
+ self.layerz[stroke[~outpix, 1], stroke[~outpix, 2]] = np.array([0, 0, 0, 0])
1258
+ cellpix = self.cellpix[cZ, stroke[:, 1], stroke[:, 2]]
1259
+ ccol = self.cellcolors.copy()
1260
+ if self.selected > 0:
1261
+ ccol[self.selected] = np.array([255, 255, 255])
1262
+ col2mask = ccol[cellpix]
1263
+ if self.masksOn:
1264
+ col2mask = np.concatenate(
1265
+ (col2mask, self.opacity * (cellpix[:, np.newaxis] > 0)), axis=-1)
1266
+ else:
1267
+ col2mask = np.concatenate((col2mask, 0 * (cellpix[:, np.newaxis] > 0)),
1268
+ axis=-1)
1269
+ self.layerz[stroke[:, 1], stroke[:, 2], :] = col2mask
1270
+ if self.outlinesOn:
1271
+ self.layerz[stroke[outpix, 1], stroke[outpix,
1272
+ 2]] = np.array(self.outcolor)
1273
+ if delete_points:
1274
+ del self.current_point_set[stroke_ind]
1275
+ self.update_layer()
1276
+
1277
+ del self.strokes[stroke_ind]
1278
+
1279
+ def plot_clicked(self, event):
1280
+ if event.button()==QtCore.Qt.LeftButton \
1281
+ and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\
1282
+ and not self.removing_region:
1283
+ if event.double():
1284
+ try:
1285
+ self.p0.setYRange(0, self.Ly + self.pr)
1286
+ except:
1287
+ self.p0.setYRange(0, self.Ly)
1288
+ self.p0.setXRange(0, self.Lx)
1289
+
1290
+ def cancel_remove_multiple(self):
1291
+ self.clear_multi_selected_cells()
1292
+ self.done_remove_multiple_cells()
1293
+
1294
+ def clear_multi_selected_cells(self):
1295
+ # unselect all previously selected cells:
1296
+ for idx in self.removing_cells_list:
1297
+ self.unselect_cell_multi(idx)
1298
+ self.removing_cells_list.clear()
1299
+
1300
+ def add_roi(self, roi):
1301
+ self.p0.addItem(roi)
1302
+ self.remove_roi_obj = roi
1303
+
1304
+ def remove_roi(self, roi):
1305
+ self.clear_multi_selected_cells()
1306
+ assert roi == self.remove_roi_obj
1307
+ self.remove_roi_obj = None
1308
+ self.p0.removeItem(roi)
1309
+ self.removing_region = False
1310
+
1311
+ def roi_changed(self, roi):
1312
+ # find the overlapping cells and make them selected
1313
+ pos = roi.pos()
1314
+ size = roi.size()
1315
+ x0 = int(pos.x())
1316
+ y0 = int(pos.y())
1317
+ x1 = int(pos.x() + size.x())
1318
+ y1 = int(pos.y() + size.y())
1319
+ if x0 < 0:
1320
+ x0 = 0
1321
+ if y0 < 0:
1322
+ y0 = 0
1323
+ if x1 > self.Lx:
1324
+ x1 = self.Lx
1325
+ if y1 > self.Ly:
1326
+ y1 = self.Ly
1327
+
1328
+ # find cells in that region
1329
+ cell_idxs = np.unique(self.cellpix[self.currentZ, y0:y1, x0:x1])
1330
+ cell_idxs = np.trim_zeros(cell_idxs)
1331
+ # deselect cells not in region by deselecting all and then selecting the ones in the region
1332
+ self.clear_multi_selected_cells()
1333
+
1334
+ for idx in cell_idxs:
1335
+ self.select_cell_multi(idx)
1336
+ self.removing_cells_list.append(idx)
1337
+
1338
+ self.update_layer()
1339
+
1340
+ def mouse_moved(self, pos):
1341
+ items = self.win.scene().items(pos)
1342
+
1343
+ def color_choose(self):
1344
+ self.color = self.RGBDropDown.currentIndex()
1345
+ self.view = 0
1346
+ self.ViewDropDown.setCurrentIndex(self.view)
1347
+ self.update_plot()
1348
+
1349
+ def update_plot(self):
1350
+ self.view = self.ViewDropDown.currentIndex()
1351
+ self.Ly, self.Lx, _ = self.stack[self.currentZ].shape
1352
+
1353
+ if self.view == 0 or self.view == self.ViewDropDown.count() - 1:
1354
+ image = self.stack[
1355
+ self.currentZ] if self.view == 0 else self.stack_filtered[self.currentZ]
1356
+ if self.color == 0:
1357
+ self.img.setImage(image, autoLevels=False, lut=None)
1358
+ if self.nchan > 1:
1359
+ levels = np.array([
1360
+ self.saturation[0][self.currentZ],
1361
+ self.saturation[1][self.currentZ],
1362
+ self.saturation[2][self.currentZ]
1363
+ ])
1364
+ self.img.setLevels(levels)
1365
+ else:
1366
+ self.img.setLevels(self.saturation[0][self.currentZ])
1367
+ elif self.color > 0 and self.color < 4:
1368
+ if self.nchan > 1:
1369
+ image = image[:, :, self.color - 1]
1370
+ self.img.setImage(image, autoLevels=False, lut=self.cmap[self.color])
1371
+ if self.nchan > 1:
1372
+ self.img.setLevels(self.saturation[self.color - 1][self.currentZ])
1373
+ else:
1374
+ self.img.setLevels(self.saturation[0][self.currentZ])
1375
+ elif self.color == 4:
1376
+ if self.nchan > 1:
1377
+ image = image.mean(axis=-1)
1378
+ self.img.setImage(image, autoLevels=False, lut=None)
1379
+ self.img.setLevels(self.saturation[0][self.currentZ])
1380
+ elif self.color == 5:
1381
+ if self.nchan > 1:
1382
+ image = image.mean(axis=-1)
1383
+ self.img.setImage(image, autoLevels=False, lut=self.cmap[0])
1384
+ self.img.setLevels(self.saturation[0][self.currentZ])
1385
+ else:
1386
+ image = np.zeros((self.Ly, self.Lx), np.uint8)
1387
+ if len(self.flows) > self.view - 1 and len(self.flows[self.view - 1]) > 0:
1388
+ image = self.flows[self.view - 1][self.currentZ]
1389
+ if self.view > 1:
1390
+ self.img.setImage(image, autoLevels=False, lut=self.bwr)
1391
+ else:
1392
+ self.img.setImage(image, autoLevels=False, lut=None)
1393
+ self.img.setLevels([0.0, 255.0])
1394
+
1395
+ for r in range(3):
1396
+ self.sliders[r].setValue([
1397
+ self.saturation[r][self.currentZ][0],
1398
+ self.saturation[r][self.currentZ][1]
1399
+ ])
1400
+ self.win.show()
1401
+ self.show()
1402
+
1403
+
1404
+ def update_layer(self):
1405
+ if self.masksOn or self.outlinesOn:
1406
+ self.layer.setImage(self.layerz, autoLevels=False)
1407
+ self.win.show()
1408
+ self.show()
1409
+
1410
+
1411
+ def add_set(self):
1412
+ if len(self.current_point_set) > 0:
1413
+ while len(self.strokes) > 0:
1414
+ self.remove_stroke(delete_points=False)
1415
+ if len(self.current_point_set[0]) > 8:
1416
+ color = self.colormap[self.ncells.get(), :3]
1417
+ median = self.add_mask(points=self.current_point_set, color=color)
1418
+ if median is not None:
1419
+ self.removed_cell = []
1420
+ self.toggle_mask_ops()
1421
+ self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :],
1422
+ axis=0)
1423
+ self.ncells += 1
1424
+ self.ismanual = np.append(self.ismanual, True)
1425
+ if self.NZ == 1:
1426
+ # only save after each cell if single image
1427
+ io._save_sets_with_check(self)
1428
+ else:
1429
+ print("GUI_ERROR: cell too small, not drawn")
1430
+ self.current_stroke = []
1431
+ self.strokes = []
1432
+ self.current_point_set = []
1433
+ self.update_layer()
1434
+
1435
+ def add_mask(self, points=None, color=(100, 200, 50), dense=True):
1436
+ # points is list of strokes
1437
+ points_all = np.concatenate(points, axis=0)
1438
+
1439
+ # loop over z values
1440
+ median = []
1441
+ zdraw = np.unique(points_all[:, 0])
1442
+ z = 0
1443
+ ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
1444
+ 0, "int"), np.zeros(0, "int")
1445
+ for stroke in points:
1446
+ stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
1447
+ vr = stroke[:, 1]
1448
+ vc = stroke[:, 2]
1449
+ # get points inside drawn points
1450
+ mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
1451
+ pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
1452
+ axis=-1)[:, np.newaxis, :]
1453
+ mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
1454
+ ar, ac = np.nonzero(mask)
1455
+ ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
1456
+ # get dense outline
1457
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
1458
+ pvc, pvr = contours[-2][0][:,0].T
1459
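+ # indexing contours[-2] keeps this working on both OpenCV 3 (which returns
+ # image, contours, hierarchy) and OpenCV 4 (which returns contours, hierarchy)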
+ vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
1460
+ # concatenate all points
1461
+ ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
1462
+ # if these pixels are overlapping with another cell, reassign them
1463
+ ioverlap = self.cellpix[z][ar, ac] > 0
1464
+ if (~ioverlap).sum() < 10:
1465
+ print("GUI_ERROR: cell < 10 pixels without overlaps, not drawn")
1466
+ return None
1467
+ elif ioverlap.sum() > 0:
1468
+ ar, ac = ar[~ioverlap], ac[~ioverlap]
1469
+ # compute outline of new mask
1470
+ mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
1471
+ mask[ar - vr.min() + 2, ac - vc.min() + 2] = 1
1472
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
1473
+ cv2.CHAIN_APPROX_NONE)
1474
+ pvc, pvr = contours[-2][0][:,0].T
1475
+ vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
1476
+ ars = np.concatenate((ars, ar), axis=0)
1477
+ acs = np.concatenate((acs, ac), axis=0)
1478
+ vrs = np.concatenate((vrs, vr), axis=0)
1479
+ vcs = np.concatenate((vcs, vc), axis=0)
1480
+
1481
+ self.draw_mask(z, ars, acs, vrs, vcs, color)
1482
+ median.append(np.array([np.median(ars), np.median(acs)]))
1483
+
1484
+ self.zdraw.append(zdraw)
1485
+ d = datetime.datetime.now()
1486
+ self.track_changes.append(
1487
+ [d.strftime("%m/%d/%Y, %H:%M:%S"), "added mask", [ar, ac]])
1488
+ return median
1489
+
1490
+ def draw_mask(self, z, ar, ac, vr, vc, color, idx=None):
1491
+ """ draw single mask using outlines and area """
1492
+ if idx is None:
1493
+ idx = self.ncells + 1
1494
+ self.cellpix[z, vr, vc] = idx
1495
+ self.cellpix[z, ar, ac] = idx
1496
+ self.outpix[z, vr, vc] = idx
1497
+ if self.restore and "upsample" in self.restore:
1498
+ if self.resize:
1499
+ self.cellpix_resize[z, vr, vc] = idx
1500
+ self.cellpix_resize[z, ar, ac] = idx
1501
+ self.outpix_resize[z, vr, vc] = idx
1502
+ self.cellpix_orig[z, (vr / self.ratio).astype(int),
1503
+ (vc / self.ratio).astype(int)] = idx
1504
+ self.cellpix_orig[z, (ar / self.ratio).astype(int),
1505
+ (ac / self.ratio).astype(int)] = idx
1506
+ self.outpix_orig[z, (vr / self.ratio).astype(int),
1507
+ (vc / self.ratio).astype(int)] = idx
1508
+ else:
1509
+ self.cellpix_orig[z, vr, vc] = idx
1510
+ self.cellpix_orig[z, ar, ac] = idx
1511
+ self.outpix_orig[z, vr, vc] = idx
1512
+
1513
+ # get upsampled mask
1514
+ vrr = (vr.copy() * self.ratio).astype(int)
1515
+ vcr = (vc.copy() * self.ratio).astype(int)
1516
+ mask = np.zeros((np.ptp(vrr) + 4, np.ptp(vcr) + 4), np.uint8)
1517
+ pts = np.stack((vcr - vcr.min() + 2, vrr - vrr.min() + 2),
1518
+ axis=-1)[:, np.newaxis, :]
1519
+ mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
1520
+ arr, acr = np.nonzero(mask)
1521
+ arr, acr = arr + vrr.min() - 2, acr + vcr.min() - 2
1522
+ # get dense outline
1523
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
1524
+ cv2.CHAIN_APPROX_NONE)
1525
+ pvc, pvr = contours[-2][0].squeeze().T
1526
+ vrr, vcr = pvr + vrr.min() - 2, pvc + vcr.min() - 2
1527
+ # concatenate all points
1528
+ arr, acr = np.hstack((np.vstack((vrr, vcr)), np.vstack((arr, acr))))
1529
+ self.cellpix_resize[z, vrr, vcr] = idx
1530
+ self.cellpix_resize[z, arr, acr] = idx
1531
+ self.outpix_resize[z, vrr, vcr] = idx
1532
+
1533
+ if z == self.currentZ:
1534
+ self.layerz[ar, ac, :3] = color
1535
+ if self.masksOn:
1536
+ self.layerz[ar, ac, -1] = self.opacity
1537
+ if self.outlinesOn:
1538
+ self.layerz[vr, vc] = np.array(self.outcolor)
1539
+
1540
+ def compute_scale(self):
1541
+ # get diameter from gui
1542
+ diameter = self.segmentation_settings.diameter
1543
+ if not diameter:
1544
+ diameter = 30
1545
+
1546
+ self.pr = int(diameter)
1547
+ self.radii_padding = int(self.pr * 1.25)
1548
+ self.radii = np.zeros((self.Ly + self.radii_padding, self.Lx, 4), np.uint8)
1549
+ yy, xx = disk([self.Ly + self.radii_padding / 2 - 1, self.pr / 2 + 1],
1550
+ self.pr / 2, self.Ly + self.radii_padding, self.Lx)
1551
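+ # the disk is a visual size reference drawn in a padded strip below the
+ # image; its diameter in pixels equals the current diameter setting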
+ # rgb(150,50,150)
1552
+ self.radii[yy, xx, 0] = 150
1553
+ self.radii[yy, xx, 1] = 50
1554
+ self.radii[yy, xx, 2] = 150
1555
+ self.radii[yy, xx, 3] = 255
1556
+ self.p0.setYRange(0, self.Ly + self.radii_padding)
1557
+ self.p0.setXRange(0, self.Lx)
1558
+
1559
+ def update_scale(self):
1560
+ self.compute_scale()
1561
+ self.scale.setImage(self.radii, autoLevels=False)
1562
+ self.scale.setLevels([0.0, 255.0])
1563
+ self.win.show()
1564
+ self.show()
1565
+
1566
+
1567
+ def draw_layer(self):
1568
+ if self.resize:
1569
+ self.Ly, self.Lx = self.Lyr, self.Lxr
1570
+ else:
1571
+ self.Ly, self.Lx = self.Ly0, self.Lx0
1572
+
1573
+ if self.masksOn or self.outlinesOn:
1574
+ if self.restore and "upsample" in self.restore:
1575
+ if self.resize:
1576
+ self.cellpix = self.cellpix_resize.copy()
1577
+ self.outpix = self.outpix_resize.copy()
1578
+ else:
1579
+ self.cellpix = self.cellpix_orig.copy()
1580
+ self.outpix = self.outpix_orig.copy()
1581
+
1582
+ self.layerz = np.zeros((self.Ly, self.Lx, 4), np.uint8)
1583
+ if self.masksOn:
1584
+ self.layerz[..., :3] = self.cellcolors[self.cellpix[self.currentZ], :]
1585
+ self.layerz[..., 3] = self.opacity * (self.cellpix[self.currentZ]
1586
+ > 0).astype(np.uint8)
1587
+ if self.selected > 0:
1588
+ self.layerz[self.cellpix[self.currentZ] == self.selected] = np.array(
1589
+ [255, 255, 255, self.opacity])
1590
+ cZ = self.currentZ
1591
+ stroke_z = np.array([s[0][0] for s in self.strokes])
1592
+ inZ = np.nonzero(stroke_z == cZ)[0]
1593
+ if len(inZ) > 0:
1594
+ for i in inZ:
1595
+ stroke = np.array(self.strokes[i])
1596
+ self.layerz[stroke[:, 1], stroke[:,
1597
+ 2]] = np.array([255, 0, 255, 100])
1598
+ else:
1599
+ self.layerz[..., 3] = 0
1600
+
1601
+ if self.outlinesOn:
1602
+ self.layerz[self.outpix[self.currentZ] > 0] = np.array(
1603
+ self.outcolor).astype(np.uint8)
1604
+
1605
+
1606
+ def set_normalize_params(self, normalize_params):
1607
+ from cellpose.models import normalize_default
1608
+ if self.restore != "filter":
1609
+ keys = list(normalize_params.keys()).copy()
1610
+ for key in keys:
1611
+ if key != "percentile":
1612
+ normalize_params[key] = normalize_default[key]
1613
+ normalize_params = {**normalize_default, **normalize_params}
1614
+ out = self.check_filter_params(normalize_params["sharpen_radius"],
1615
+ normalize_params["smooth_radius"],
1616
+ normalize_params["tile_norm_blocksize"],
1617
+ normalize_params["tile_norm_smooth3D"],
1618
+ normalize_params["norm3D"],
1619
+ normalize_params["invert"])
1620
+
1621
+
1622
+ def check_filter_params(self, sharpen, smooth, tile_norm, smooth3D, norm3D, invert):
1623
+ tile_norm = 0 if tile_norm < 0 else tile_norm
1624
+ sharpen = 0 if sharpen < 0 else sharpen
1625
+ smooth = 0 if smooth < 0 else smooth
1626
+ smooth3D = 0 if smooth3D < 0 else smooth3D
1627
+ norm3D = bool(norm3D)
1628
+ invert = bool(invert)
1629
+ if tile_norm > self.Ly and tile_norm > self.Lx:
1630
+ print(
1631
+ "GUI_ERROR: tile size (tile_norm) bigger than both image dimensions, disabling"
1632
+ )
1633
+ tile_norm = 0
1634
+ self.filt_edits[0].setText(str(sharpen))
1635
+ self.filt_edits[1].setText(str(smooth))
1636
+ self.filt_edits[2].setText(str(tile_norm))
1637
+ self.filt_edits[3].setText(str(smooth3D))
1638
+ self.norm3D_cb.setChecked(norm3D)
1639
+ return sharpen, smooth, tile_norm, smooth3D, norm3D, invert
1640
+
1641
+ def get_normalize_params(self):
1642
+ percentile = [
1643
+ self.segmentation_settings.low_percentile,
1644
+ self.segmentation_settings.high_percentile,
1645
+ ]
1646
+ normalize_params = {"percentile": percentile}
1647
+ norm3D = self.norm3D_cb.isChecked()
1648
+ normalize_params["norm3D"] = norm3D
1649
+ sharpen = float(self.filt_edits[0].text())
1650
+ smooth = float(self.filt_edits[1].text())
1651
+ tile_norm = float(self.filt_edits[2].text())
1652
+ smooth3D = float(self.filt_edits[3].text())
1653
+ invert = False
1654
+ out = self.check_filter_params(sharpen, smooth, tile_norm, smooth3D, norm3D,
1655
+ invert)
1656
+ sharpen, smooth, tile_norm, smooth3D, norm3D, invert = out
1657
+ normalize_params["sharpen_radius"] = sharpen
1658
+ normalize_params["smooth_radius"] = smooth
1659
+ normalize_params["tile_norm_blocksize"] = tile_norm
1660
+ normalize_params["tile_norm_smooth3D"] = smooth3D
1661
+ normalize_params["invert"] = invert
1662
+
1663
+ from cellpose.models import normalize_default
1664
+ normalize_params = {**normalize_default, **normalize_params}
1665
+
1666
+ return normalize_params
1667
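+ # sketch of the returned dict (values illustrative, assuming cellpose's
+ # normalize_default): {"percentile": [1., 99.], "norm3D": False,
+ # "sharpen_radius": 0, "smooth_radius": 0, "tile_norm_blocksize": 0,
+ # "tile_norm_smooth3D": 1, "invert": False, ...}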
+
1668
+ def compute_saturation_if_checked(self):
1669
+ if self.autobtn.isChecked():
1670
+ self.compute_saturation()
1671
+
1672
+ def compute_saturation(self, return_img=False):
1673
+ norm = self.get_normalize_params()
1674
+ print(norm)
1675
+ sharpen, smooth = norm["sharpen_radius"], norm["smooth_radius"]
1676
+ percentile = norm["percentile"]
1677
+ tile_norm = norm["tile_norm_blocksize"]
1678
+ invert = norm["invert"]
1679
+ norm3D = norm["norm3D"]
1680
+ smooth3D = norm["tile_norm_smooth3D"]
1681
+ tile_norm = norm["tile_norm_blocksize"]
1682
+
1683
+ if sharpen > 0 or smooth > 0 or tile_norm > 0:
1684
+ img_norm = self.stack.copy()
1685
+ else:
1686
+ img_norm = self.stack
1687
+
1688
+ if sharpen > 0 or smooth > 0 or tile_norm > 0:
1689
+ self.restore = "filter"
1690
+ print(
1691
+ "GUI_INFO: computing filtered image because sharpen > 0 or tile_norm > 0"
1692
+ )
1693
+ print(
1694
+ "GUI_WARNING: will use memory to create filtered image -- make sure to have RAM for this"
1695
+ )
1696
+ img_norm = self.stack.copy()
1697
+ if sharpen > 0 or smooth > 0:
1698
+ img_norm = smooth_sharpen_img(self.stack, sharpen_radius=sharpen,
1699
+ smooth_radius=smooth)
1700
+
1701
+ if tile_norm > 0:
1702
+ img_norm = normalize99_tile(img_norm, blocksize=tile_norm,
1703
+ lower=percentile[0], upper=percentile[1],
1704
+ smooth3D=smooth3D, norm3D=norm3D)
1705
+ # convert to 0->255
1706
+ img_norm_min = img_norm.min()
1707
+ img_norm_max = img_norm.max()
1708
+ for c in range(img_norm.shape[-1]):
1709
+ if np.ptp(img_norm[..., c]) > 1e-3:
1710
+ img_norm[..., c] -= img_norm_min
1711
+ img_norm[..., c] /= (img_norm_max - img_norm_min)
1712
+ img_norm *= 255
1713
+ self.stack_filtered = img_norm
1714
+ self.ViewDropDown.model().item(self.ViewDropDown.count() -
1715
+ 1).setEnabled(True)
1716
+ self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1)
1717
+ else:
1718
+ img_norm = self.stack if self.restore is None or self.restore == "filter" else self.stack_filtered
1719
+
1720
+ if self.autobtn.isChecked():
1721
+ self.saturation = []
1722
+ for c in range(img_norm.shape[-1]):
1723
+ self.saturation.append([])
1724
+ if np.ptp(img_norm[..., c]) > 1e-3:
1725
+ if norm3D:
1726
+ x01 = np.percentile(img_norm[..., c], percentile[0])
1727
+ x99 = np.percentile(img_norm[..., c], percentile[1])
1728
+ if invert:
1729
+ x01i = 255. - x99
1730
+ x99i = 255. - x01
1731
+ x01, x99 = x01i, x99i
1732
+ for n in range(self.NZ):
1733
+ self.saturation[-1].append([x01, x99])
1734
+ else:
1735
+ for z in range(self.NZ):
1736
+ if self.NZ > 1:
1737
+ x01 = np.percentile(img_norm[z, :, :, c], percentile[0])
1738
+ x99 = np.percentile(img_norm[z, :, :, c], percentile[1])
1739
+ else:
1740
+ x01 = np.percentile(img_norm[..., c], percentile[0])
1741
+ x99 = np.percentile(img_norm[..., c], percentile[1])
1742
+ if invert:
1743
+ x01i = 255. - x99
1744
+ x99i = 255. - x01
1745
+ x01, x99 = x01i, x99i
1746
+ self.saturation[-1].append([x01, x99])
1747
+ else:
1748
+ for n in range(self.NZ):
1749
+ self.saturation[-1].append([0, 255.])
1750
+ print(self.saturation[2][self.currentZ])
1751
+
1752
+ if img_norm.shape[-1] == 1:
1753
+ self.saturation.append(self.saturation[0])
1754
+ self.saturation.append(self.saturation[0])
1755
+
1756
+ # self.autobtn.setChecked(True)
1757
+ self.update_plot()
1758
+
1759
+
1760
+ def get_model_path(self, custom=False):
1761
+ if custom:
1762
+ self.current_model = self.ModelChooseC.currentText()
1763
+ self.current_model_path = os.fspath(
1764
+ models.MODEL_DIR.joinpath(self.current_model))
1765
+ else:
1766
+ self.current_model = "cpsam"
1767
+ self.current_model_path = models.model_path(self.current_model)
1768
+
1769
+ def initialize_model(self, model_name=None, custom=False):
1770
+ if model_name is None or custom:
1771
+ self.get_model_path(custom=custom)
1772
+ if not os.path.exists(self.current_model_path):
1773
+ raise ValueError("need to specify model (use dropdown)")
1774
+
1775
+ if model_name is None or not isinstance(model_name, str):
1776
+ self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
1777
+ pretrained_model=self.current_model_path)
1778
+ else:
1779
+ self.current_model = model_name
1780
+ self.current_model_path = os.fspath(
1781
+ models.MODEL_DIR.joinpath(self.current_model))
1782
+
1783
+ self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
1784
+ pretrained_model=self.current_model)
1785
+
1786
+ def add_model(self):
1787
+ io._add_model(self)
1788
+ return
1789
+
1790
+ def remove_model(self):
1791
+ io._remove_model(self)
1792
+ return
1793
+
1794
+ def new_model(self):
1795
+ if self.NZ != 1:
1796
+ print("ERROR: cannot train model on 3D data")
1797
+ return
1798
+
1799
+ # train model
1800
+ image_names = self.get_files()[0]
1801
+ self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set(
1802
+ image_names)
1803
+ TW = guiparts.TrainWindow(self, models.MODEL_NAMES)
1804
+ train = TW.exec_()
1805
+ if train:
1806
+ self.logger.info(
1807
+ f"training with {[os.path.split(f)[1] for f in self.train_files]}")
1808
+ self.train_model(restore=restore, normalize_params=normalize_params)
1809
+ else:
1810
+ print("GUI_INFO: training cancelled")
1811
+
1812
+ def train_model(self, restore=None, normalize_params=None):
1813
+ from cellpose.models import normalize_default
1814
+ if normalize_params is None:
1815
+ normalize_params = copy.deepcopy(normalize_default)
1816
+ model_type = models.MODEL_NAMES[self.training_params["model_index"]]
1817
+ self.logger.info(f"training new model starting at model {model_type}")
1818
+ self.current_model = model_type
1819
+
1820
+ self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
1821
+ model_type=model_type)
1822
+ save_path = os.path.dirname(self.filename)
1823
+
1824
+ print("GUI_INFO: name of new model: " + self.training_params["model_name"])
1825
+ self.new_model_path, train_losses = train.train_seg(
1826
+ self.model.net, train_data=self.train_data, train_labels=self.train_labels,
1827
+ normalize=normalize_params, min_train_masks=0,
1828
+ save_path=save_path, nimg_per_epoch=max(2, len(self.train_data)),
1829
+ learning_rate=self.training_params["learning_rate"],
1830
+ weight_decay=self.training_params["weight_decay"],
1831
+ n_epochs=self.training_params["n_epochs"],
1832
+ model_name=self.training_params["model_name"])[:2]
1833
+ # save train losses
1834
+ np.save(str(self.new_model_path) + "_train_losses.npy", train_losses)
1835
+ # run model on next image
1836
+ io._add_model(self, self.new_model_path)
1837
+ diam_labels = self.model.net.diam_labels.item() #.copy()
1838
+ self.new_model_ind = len(self.model_strings)
1839
+ self.autorun = True
1840
+ self.clear_all()
1841
+ self.restore = restore
1842
+ self.set_normalize_params(normalize_params)
1843
+ self.get_next_image(load_seg=False)
1844
+
1845
+ self.compute_segmentation(custom=True)
1846
+ self.logger.info(
1847
+ f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!"
1848
+ )
1849
+
1850
+
1851
+ def compute_cprob(self):
1852
+ if self.recompute_masks:
1853
+ flow_threshold = self.segmentation_settings.flow_threshold
1854
+ cellprob_threshold = self.segmentation_settings.cellprob_threshold
1855
+ niter = self.segmentation_settings.niter
1856
+ min_size = int(self.min_size.text()) if not isinstance(
1857
+ self.min_size, int) else self.min_size
1858
+
1859
+ self.logger.info(
1860
+ "computing masks with cell prob=%0.3f, flow error threshold=%0.3f" %
1861
+ (cellprob_threshold, flow_threshold))
1862
+
1863
+ try:
1864
+ dP = self.flows[2].squeeze()
1865
+ cellprob = self.flows[3].squeeze()
1866
+ except IndexError:
1867
+ self.logger.error("Flows don't exist, try running model again.")
1868
+ return
1869
+
1870
+ maski = dynamics.resize_and_compute_masks(
1871
+ dP=dP,
1872
+ cellprob=cellprob,
1873
+ niter=niter,
1874
+ do_3D=self.load_3D,
1875
+ min_size=min_size,
1876
+ # max_size_fraction=min_size_fraction, # Leave as default
1877
+ cellprob_threshold=cellprob_threshold,
1878
+ flow_threshold=flow_threshold)
1879
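+ # recomputing masks from the cached dP / cellprob lets threshold edits
+ # take effect without re-running the network; the flows were stored by
+ # compute_segmentation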
+
1880
+ self.masksOn = True
1881
+ if not self.OCheckBox.isChecked():
1882
+ self.MCheckBox.setChecked(True)
1883
+ if maski.ndim < 3:
1884
+ maski = maski[np.newaxis, ...]
1885
+ self.logger.info("%d cells found" % (len(np.unique(maski)[1:])))
1886
+ io._masks_to_gui(self, maski, outlines=None)
1887
+ self.show()
1888
+
1889
+
1890
+ def compute_segmentation(self, custom=False, model_name=None, load_model=True):
1891
+ self.progress.setValue(0)
1892
+ try:
1893
+ tic = time.time()
1894
+ self.clear_all()
1895
+ self.flows = [[], [], []]
1896
+ if load_model:
1897
+ self.initialize_model(model_name=model_name, custom=custom)
1898
+ self.progress.setValue(10)
1899
+ do_3D = self.load_3D
1900
+ stitch_threshold = float(self.stitch_threshold.text()) if not isinstance(
1901
+ self.stitch_threshold, float) else self.stitch_threshold
1902
+ anisotropy = float(self.anisotropy.text()) if not isinstance(
1903
+ self.anisotropy, float) else self.anisotropy
1904
+ flow3D_smooth = float(self.flow3D_smooth.text()) if not isinstance(
1905
+ self.flow3D_smooth, float) else self.flow3D_smooth
1906
+ min_size = int(self.min_size.text()) if not isinstance(
1907
+ self.min_size, int) else self.min_size
1908
+
1909
+ do_3D = False if stitch_threshold > 0. else do_3D
1910
+
1911
+ if self.restore == "filter":
1912
+ data = self.stack_filtered.copy().squeeze()
1913
+ else:
1914
+ data = self.stack.copy().squeeze()
1915
+
1916
+ flow_threshold = self.segmentation_settings.flow_threshold
1917
+ cellprob_threshold = self.segmentation_settings.cellprob_threshold
1918
+ diameter = self.segmentation_settings.diameter
1919
+ niter = self.segmentation_settings.niter
1920
+
1921
+ normalize_params = self.get_normalize_params()
1922
+ print(normalize_params)
1923
+ try:
1924
+ masks, flows = self.model.eval(
1925
+ data,
1926
+ diameter=diameter,
1927
+ cellprob_threshold=cellprob_threshold,
1928
+ flow_threshold=flow_threshold, do_3D=do_3D, niter=niter,
1929
+ normalize=normalize_params, stitch_threshold=stitch_threshold,
1930
+ anisotropy=anisotropy, flow3D_smooth=flow3D_smooth,
1931
+ min_size=min_size, channel_axis=-1,
1932
+ progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2]
1933
+ except Exception as e:
1934
+ print("NET ERROR: %s" % e)
1935
+ self.progress.setValue(0)
1936
+ return
1937
+
1938
+ self.progress.setValue(75)
1939
+
1940
+ # convert flows to uint8 and resize to original image size
1941
+ flows_new = []
1942
+ flows_new.append(flows[0].copy()) # RGB flow
1943
+ flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) *
1944
+ 255).astype("uint8")) # cellprob
1945
+ flows_new.append(flows[1].copy()) # XY flows
1946
+ flows_new.append(flows[2].copy()) # original cellprob
1947
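+ # resulting self.flows layout (as read back by compute_cprob above):
+ # [0] RGB flow image, [1] cellprob rescaled to uint8 for display,
+ # [2] raw XY flows (dP), [3] raw cellprob, and for 3D a Z-flow at [4]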
+
1948
+ if self.load_3D:
1949
+ if stitch_threshold == 0.:
1950
+ flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8"))
1951
+ else:
1952
+ flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8"))
1953
+
1954
+ if not self.load_3D:
1955
+ if self.restore and "upsample" in self.restore:
1956
+ self.Ly, self.Lx = self.Lyr, self.Lxr
1957
+
1958
+ if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx):
1959
+ self.flows = []
1960
+ for j in range(len(flows_new)):
1961
+ self.flows.append(
1962
+ resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx,
1963
+ interpolation=cv2.INTER_NEAREST))
1964
+ else:
1965
+ self.flows = flows_new
1966
+ else:
1967
+ self.flows = []
1968
+ Lz, Ly, Lx = self.NZ, self.Ly, self.Lx
1969
+ Lz0, Ly0, Lx0 = flows_new[0].shape[:3]
1970
+ print("GUI_INFO: resizing flows to original image size")
1971
+ for j in range(len(flows_new)):
1972
+ flow0 = flows_new[j]
1973
+ if Ly0 != Ly:
1974
+ flow0 = resize_image(flow0, Ly=Ly, Lx=Lx,
1975
+ no_channels=flow0.ndim==3,
1976
+ interpolation=cv2.INTER_NEAREST)
1977
+ if Lz0 != Lz:
1978
+ flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1),
1979
+ Ly=Lz, Lx=Lx,
1980
+ no_channels=flow0.ndim==3,
1981
+ interpolation=cv2.INTER_NEAREST), 0, 1)
1982
+ self.flows.append(flow0)
1983
+
1984
+ # add first axis
1985
+ if self.NZ == 1:
1986
+ masks = masks[np.newaxis, ...]
1987
+ self.flows = [
1988
+ self.flows[n][np.newaxis, ...] for n in range(len(self.flows))
1989
+ ]
1990
+
1991
+ self.logger.info("%d cells found with model in %0.3f sec" %
1992
+ (len(np.unique(masks)[1:]), time.time() - tic))
1993
+ self.progress.setValue(80)
1994
+ z = 0
1995
+
1996
+ io._masks_to_gui(self, masks, outlines=None)
1997
+ self.masksOn = True
1998
+ self.MCheckBox.setChecked(True)
1999
+ self.progress.setValue(100)
2000
+ if self.restore != "filter" and self.restore is not None and self.autobtn.isChecked():
2001
+ self.compute_saturation()
2002
+ if not do_3D and not stitch_threshold > 0:
2003
+ self.recompute_masks = True
2004
+ else:
2005
+ self.recompute_masks = False
2006
+ except Exception as e:
2007
+ print("ERROR: %s" % e)
models/seg_post_model/cellpose/gui/gui3d.py ADDED
@@ -0,0 +1,667 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+
5
+ import sys, pathlib, warnings
6
+
7
+ from qtpy import QtGui, QtCore
8
+ from qtpy.QtWidgets import QApplication, QScrollBar, QCheckBox, QLabel, QLineEdit
9
+ import pyqtgraph as pg
10
+
11
+ import numpy as np
12
+ from scipy.stats import mode
13
+ import cv2
14
+
15
+ from . import guiparts, io
16
+ from ..utils import download_url_to_file, masks_to_outlines
17
+ from .gui import MainW
18
+
19
+ try:
20
+ import matplotlib.pyplot as plt
21
+ MATPLOTLIB = True
22
+ except:
23
+ MATPLOTLIB = False
24
+
25
+
26
+ def avg3d(C):
27
+ """ smooth value of c across nearby points
28
+ (c is center of grid directly below point)
29
+ b -- a -- b
30
+ a -- c -- a
31
+ b -- a -- b
32
+ """
33
+ Ly, Lx = C.shape
34
+ # pad T by 2
35
+ T = np.zeros((Ly + 2, Lx + 2), "float32")
36
+ M = np.zeros((Ly, Lx), "float32")
37
+ T[1:-1, 1:-1] = C.copy()
38
+ y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
39
+ indexing="ij")
40
+ y += 1
41
+ x += 1
42
+ a = 1. / 2 #/(z**2 + 1)**0.5
43
+ b = 1. / (1 + 2**0.5) #(z**2 + 2)**0.5
44
+ c = 1.
45
+ M = (b * T[y - 1, x - 1] + a * T[y - 1, x] + b * T[y - 1, x + 1] + a * T[y, x - 1] +
46
+ c * T[y, x] + a * T[y, x + 1] + b * T[y + 1, x - 1] + a * T[y + 1, x] +
47
+ b * T[y + 1, x + 1])
48
+ M /= 4 * a + 4 * b + c
49
+ return M
50
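+ # the weights form a 3x3 kernel: center c=1, edge neighbours a=1/2,
+ # diagonal neighbours b=1/(1+sqrt(2)); dividing by 4a + 4b + c keeps the
+ # smoothed mask values in [0, 1]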
+
51
+
52
+ def interpZ(mask, zdraw):
53
+ """ find nearby planes and average their values using grid of points
54
+ zfill is in ascending order
55
+ """
56
+ ifill = np.ones(mask.shape[0], "bool")
57
+ zall = np.arange(0, mask.shape[0], 1, int)
58
+ ifill[zdraw] = False
59
+ zfill = zall[ifill]
60
+ zlower = zdraw[np.searchsorted(zdraw, zfill, side="left") - 1]
61
+ zupper = zdraw[np.searchsorted(zdraw, zfill, side="right")]
62
+ for k, z in enumerate(zfill):
63
+ Z = zupper[k] - zlower[k]
64
+ zl = (z - zlower[k]) / Z
65
+ plower = avg3d(mask[zlower[k]]) * (1 - zl)
66
+ pupper = avg3d(mask[zupper[k]]) * zl
67
+ mask[z] = (plower + pupper) > 0.33
68
+ return mask, zfill
69
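+ # worked example (illustrative): for a plane z=3 between drawn planes 2
+ # and 5, zl = (3-2)/(5-2) = 1/3, so the filled mask is 2/3 of the smoothed
+ # plane 2 plus 1/3 of the smoothed plane 5, thresholded at 0.33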
+
70
+
71
+ def run(image=None):
72
+ from ..io import logger_setup
73
+ logger, log_file = logger_setup()
74
+ # Always start by initializing Qt (only once per application)
75
+ warnings.filterwarnings("ignore")
76
+ app = QApplication(sys.argv)
77
+ icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
78
+ guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png")
79
+ style_path = pathlib.Path.home().joinpath(".cellpose", "style_choice.npy")
80
+ if not icon_path.is_file():
81
+ cp_dir = pathlib.Path.home().joinpath(".cellpose")
82
+ cp_dir.mkdir(exist_ok=True)
83
+ print("downloading logo")
84
+ download_url_to_file(
85
+ "https://www.cellpose.org/static/images/cellpose_transparent.png",
86
+ icon_path, progress=True)
87
+ if not guip_path.is_file():
88
+ print("downloading help window image")
89
+ download_url_to_file("https://www.cellpose.org/static/images/cellpose_gui.png",
90
+ guip_path, progress=True)
91
+ icon_path = str(icon_path.resolve())
92
+ app_icon = QtGui.QIcon()
93
+ app_icon.addFile(icon_path, QtCore.QSize(16, 16))
94
+ app_icon.addFile(icon_path, QtCore.QSize(24, 24))
95
+ app_icon.addFile(icon_path, QtCore.QSize(32, 32))
96
+ app_icon.addFile(icon_path, QtCore.QSize(48, 48))
97
+ app_icon.addFile(icon_path, QtCore.QSize(64, 64))
98
+ app_icon.addFile(icon_path, QtCore.QSize(256, 256))
99
+ app.setWindowIcon(app_icon)
100
+ app.setStyle("Fusion")
101
+ app.setPalette(guiparts.DarkPalette())
102
+ MainW_3d(image=image, logger=logger)
103
+ ret = app.exec_()
104
+ sys.exit(ret)
105
+
106
+
107
+ class MainW_3d(MainW):
108
+
109
+ def __init__(self, image=None, logger=None):
110
+ # MainW init
111
+ MainW.__init__(self, image=image, logger=logger)
112
+
113
+ # add gradZ view
114
+ self.ViewDropDown.insertItem(3, "gradZ")
115
+
116
+ # turn off single stroke
117
+ self.SCheckBox.setChecked(False)
118
+
119
+ ### add orthoviews and z-bar
120
+ # ortho crosshair lines
121
+ self.vLine = pg.InfiniteLine(angle=90, movable=False)
122
+ self.hLine = pg.InfiniteLine(angle=0, movable=False)
123
+ self.vLineOrtho = [
124
+ pg.InfiniteLine(angle=90, movable=False),
125
+ pg.InfiniteLine(angle=90, movable=False)
126
+ ]
127
+ self.hLineOrtho = [
128
+ pg.InfiniteLine(angle=0, movable=False),
129
+ pg.InfiniteLine(angle=0, movable=False)
130
+ ]
131
+ self.make_orthoviews()
132
+
133
+ # z scrollbar underneath
134
+ self.scroll = QScrollBar(QtCore.Qt.Horizontal)
135
+ self.scroll.setMaximum(10)
136
+ self.scroll.valueChanged.connect(self.move_in_Z)
137
+ self.lmain.addWidget(self.scroll, 40, 9, 1, 30)
138
+
139
+ b = 22
140
+
141
+ label = QLabel("stitch\nthreshold:")
142
+ label.setToolTip(
143
+ "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
144
+ )
145
+ label.setFont(self.medfont)
146
+ self.segBoxG.addWidget(label, b, 0, 1, 4)
147
+ self.stitch_threshold = QLineEdit()
148
+ self.stitch_threshold.setText("0.0")
149
+ self.stitch_threshold.setFixedWidth(30)
150
+ self.stitch_threshold.setFont(self.medfont)
151
+ self.stitch_threshold.setToolTip(
152
+ "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
153
+ )
154
+ self.segBoxG.addWidget(self.stitch_threshold, b, 3, 1, 1)
155
+
156
+ label = QLabel("flow3D\nsmooth:")
157
+ label.setToolTip(
158
+ "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)"
159
+ )
160
+ label.setFont(self.medfont)
161
+ self.segBoxG.addWidget(label, b, 4, 1, 3)
162
+ self.flow3D_smooth = QLineEdit()
163
+ self.flow3D_smooth.setText("0.0")
164
+ self.flow3D_smooth.setFixedWidth(30)
165
+ self.flow3D_smooth.setFont(self.medfont)
166
+ self.flow3D_smooth.setToolTip(
167
+ "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)"
168
+ )
169
+ self.segBoxG.addWidget(self.flow3D_smooth, b, 7, 1, 1)
170
+
171
+ b += 1
172
+ label = QLabel("anisotropy:")
173
+ label.setToolTip(
174
+ "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)"
175
+ )
176
+ label.setFont(self.medfont)
177
+ self.segBoxG.addWidget(label, b, 0, 1, 3)
178
+ self.anisotropy = QLineEdit()
179
+ self.anisotropy.setText("1.0")
180
+ self.anisotropy.setFixedWidth(30)
181
+ self.anisotropy.setFont(self.medfont)
182
+ self.anisotropy.setToolTip(
183
+ "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)"
184
+ )
185
+ self.segBoxG.addWidget(self.anisotropy, b, 3, 1, 1)
186
+
187
+ b += 1
188
+ label = QLabel("min\nsize:")
189
+ label.setToolTip(
190
+ "all masks less than this size in pixels (volume) will be removed"
191
+ )
192
+ label.setFont(self.medfont)
193
+ self.segBoxG.addWidget(label, b, 0, 1, 4)
194
+ self.min_size = QLineEdit()
195
+ self.min_size.setText("15")
196
+ self.min_size.setFixedWidth(50)
197
+ self.min_size.setFont(self.medfont)
198
+ self.min_size.setToolTip(
199
+ "all masks less than this size in pixels (volume) will be removed"
200
+ )
201
+ self.segBoxG.addWidget(self.min_size, b, 3, 1, 1)
202
+
203
+ b += 1
204
+ self.orthobtn = QCheckBox("ortho")
205
+ self.orthobtn.setToolTip("activate orthoviews with 3D image")
206
+ self.orthobtn.setFont(self.medfont)
207
+ self.orthobtn.setChecked(False)
208
+ self.l0.addWidget(self.orthobtn, b, 0, 1, 2)
209
+ self.orthobtn.toggled.connect(self.toggle_ortho)
210
+
211
+ label = QLabel("dz:")
212
+ label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
213
+ label.setFont(self.medfont)
214
+ self.l0.addWidget(label, b, 2, 1, 1)
215
+ self.dz = 10
216
+ self.dzedit = QLineEdit()
217
+ self.dzedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
218
+ self.dzedit.setText(str(self.dz))
219
+ self.dzedit.returnPressed.connect(self.update_ortho)
220
+ self.dzedit.setFixedWidth(40)
221
+ self.dzedit.setFont(self.medfont)
222
+ self.l0.addWidget(self.dzedit, b, 3, 1, 2)
223
+
224
+ label = QLabel("z-aspect:")
225
+ label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
226
+ label.setFont(self.medfont)
227
+ self.l0.addWidget(label, b, 5, 1, 2)
228
+ self.zaspect = 1.0
229
+ self.zaspectedit = QLineEdit()
230
+ self.zaspectedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
231
+ self.zaspectedit.setText(str(self.zaspect))
232
+ self.zaspectedit.returnPressed.connect(self.update_ortho)
233
+ self.zaspectedit.setFixedWidth(40)
234
+ self.zaspectedit.setFont(self.medfont)
235
+ self.l0.addWidget(self.zaspectedit, b, 7, 1, 2)
236
+
237
+ b += 1
238
+ # add z position underneath
239
+ self.currentZ = 0
240
+ label = QLabel("Z:")
241
+ label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
242
+ self.l0.addWidget(label, b, 5, 1, 2)
243
+ self.zpos = QLineEdit()
244
+ self.zpos.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
245
+ self.zpos.setText(str(self.currentZ))
246
+ self.zpos.returnPressed.connect(self.update_ztext)
247
+ self.zpos.setFixedWidth(40)
248
+ self.zpos.setFont(self.medfont)
249
+ self.l0.addWidget(self.zpos, b, 7, 1, 2)
250
+
251
+ # if called with image, load it
252
+ if image is not None:
253
+ self.filename = image
254
+ io._load_image(self, self.filename, load_3D=True)
255
+
256
+ self.load_3D = True
257
+
258
+ def add_mask(self, points=None, color=(100, 200, 50), dense=True):
259
+ # points is list of strokes
260
+
261
+ points_all = np.concatenate(points, axis=0)
262
+
263
+ # loop over z values
264
+ median = []
265
+ zdraw = np.unique(points_all[:, 0])
266
+ zrange = np.arange(zdraw.min(), zdraw.max() + 1, 1, int)
267
+ zmin = zdraw.min()
268
+ pix = np.zeros((2, 0), "uint16")
269
+ mall = np.zeros((len(zrange), self.Ly, self.Lx), "bool")
270
+ k = 0
271
+ for z in zdraw:
272
+ ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
273
+ 0, "int"), np.zeros(0, "int")
274
+ for stroke in points:
275
+ stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
276
+ iz = stroke[:, 0] == z
277
+ vr = stroke[iz, 1]
278
+ vc = stroke[iz, 2]
279
+ if iz.sum() > 0:
280
+ # get points inside drawn points
281
+ mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8")
282
+ pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
283
+ axis=-1)[:, np.newaxis, :]
284
+ mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
285
+ ar, ac = np.nonzero(mask)
286
+ ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
287
+ # get dense outline
288
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
289
+ cv2.CHAIN_APPROX_NONE)
290
+ pvc, pvr = contours[-2][0].squeeze().T
291
+ vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
292
+ # concatenate all points
293
+ ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
294
+ # if these pixels are overlapping with another cell, reassign them
295
+ ioverlap = self.cellpix[z][ar, ac] > 0
296
+ if (~ioverlap).sum() < 8:
297
+ print("ERROR: cell too small without overlaps, not drawn")
298
+ return None
299
+ elif ioverlap.sum() > 0:
300
+ ar, ac = ar[~ioverlap], ac[~ioverlap]
301
+ # compute outline of new mask
302
+ mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8")
303
+ mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
304
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
305
+ cv2.CHAIN_APPROX_NONE)
306
+ pvc, pvr = contours[-2][0].squeeze().T
307
+ vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
308
+ ars = np.concatenate((ars, ar), axis=0)
309
+ acs = np.concatenate((acs, ac), axis=0)
310
+ vrs = np.concatenate((vrs, vr), axis=0)
311
+ vcs = np.concatenate((vcs, vc), axis=0)
312
+ self.draw_mask(z, ars, acs, vrs, vcs, color)
313
+
314
+ median.append(np.array([np.median(ars), np.median(acs)]))
315
+ mall[z - zmin, ars, acs] = True
316
+ pix = np.append(pix, np.vstack((ars, acs)), axis=-1)
317
+
318
+ mall = mall[:, pix[0].min():pix[0].max() + 1,
319
+ pix[1].min():pix[1].max() + 1].astype("float32")
320
+ ymin, xmin = pix[0].min(), pix[1].min()
321
+ if len(zdraw) > 1:
322
+ mall, zfill = interpZ(mall, zdraw - zmin)
323
+ for z in zfill:
324
+ mask = mall[z].copy()
325
+ ar, ac = np.nonzero(mask)
326
+ ioverlap = self.cellpix[z + zmin][ar + ymin, ac + xmin] > 0
327
+ if (~ioverlap).sum() < 5:
328
+ print("WARNING: stroke on plane %d not included due to overlaps" %
329
+ z)
330
+ elif ioverlap.sum() > 0:
331
+ mask[ar[ioverlap], ac[ioverlap]] = 0
332
+ ar, ac = ar[~ioverlap], ac[~ioverlap]
333
+ # compute outline of mask
334
+ outlines = masks_to_outlines(mask)
335
+ vr, vc = np.nonzero(outlines)
336
+ vr, vc = vr + ymin, vc + xmin
337
+ ar, ac = ar + ymin, ac + xmin
338
+ self.draw_mask(z + zmin, ar, ac, vr, vc, color)
339
+
340
+ self.zdraw.append(zdraw)
341
+
342
+ return median
343
+
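The stroke-to-mask step in add_mask above rasterizes each closed stroke with cv2.fillPoly and recovers a dense outline with cv2.findContours. A minimal standalone sketch of that step, using toy stroke coordinates rather than the GUI state:

    import numpy as np
    import cv2

    # toy closed stroke in (row, col) image coordinates
    vr = np.array([10, 10, 20, 20])
    vc = np.array([10, 20, 20, 10])

    # rasterize inside a padded local patch, as add_mask does
    mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8")
    pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
                   axis=-1).astype(np.int32)[:, np.newaxis, :]
    mask = cv2.fillPoly(mask, [pts], 255)

    # interior pixels, shifted back to image coordinates
    ar, ac = np.nonzero(mask)
    ar, ac = ar + vr.min() - 2, ac + vc.min() - 2

    # dense outline (contours[-2] works for both OpenCV 3 and 4 return tuples)
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    pvc, pvr = contours[-2][0].squeeze().T
    outline_r, outline_c = pvr + vr.min() - 2, pvc + vc.min() - 2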
344
+ def move_in_Z(self):
345
+ if self.loaded:
346
+ self.currentZ = min(self.NZ - 1, max(0, int(self.scroll.value())))
347
+ self.zpos.setText(str(self.currentZ))
348
+ self.update_plot()
349
+ self.draw_layer()
350
+ self.update_layer()
351
+
352
+ def make_orthoviews(self):
353
+ self.pOrtho, self.imgOrtho, self.layerOrtho = [], [], []
354
+ for j in range(2):
355
+ self.pOrtho.append(
356
+ pg.ViewBox(lockAspect=True, name=f"plotOrtho{j}",
357
+ border=[100, 100, 100], invertY=True, enableMouse=False))
358
+ self.pOrtho[j].setMenuEnabled(False)
359
+
360
+ self.imgOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self))
361
+ self.imgOrtho[j].autoDownsample = False
362
+
363
+ self.layerOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self))
364
+ self.layerOrtho[j].setLevels([0., 255.])
365
+
366
+ #self.pOrtho[j].scene().contextMenuItem = self.pOrtho[j]
367
+ self.pOrtho[j].addItem(self.imgOrtho[j])
368
+ self.pOrtho[j].addItem(self.layerOrtho[j])
369
+ self.pOrtho[j].addItem(self.vLineOrtho[j], ignoreBounds=False)
370
+ self.pOrtho[j].addItem(self.hLineOrtho[j], ignoreBounds=False)
371
+
372
+ self.pOrtho[0].linkView(self.pOrtho[0].YAxis, self.p0)
373
+ self.pOrtho[1].linkView(self.pOrtho[1].XAxis, self.p0)
374
+
375
+ def add_orthoviews(self):
376
+ self.yortho = self.Ly // 2
377
+ self.xortho = self.Lx // 2
378
+ if self.NZ > 1:
379
+ self.update_ortho()
380
+
381
+ self.win.addItem(self.pOrtho[0], 0, 1, rowspan=1, colspan=1)
382
+ self.win.addItem(self.pOrtho[1], 1, 0, rowspan=1, colspan=1)
383
+
384
+ qGraphicsGridLayout = self.win.ci.layout
385
+ qGraphicsGridLayout.setColumnStretchFactor(0, 2)
386
+ qGraphicsGridLayout.setColumnStretchFactor(1, 1)
387
+ qGraphicsGridLayout.setRowStretchFactor(0, 2)
388
+ qGraphicsGridLayout.setRowStretchFactor(1, 1)
389
+
390
+ self.pOrtho[0].setYRange(0, self.Lx)
391
+ self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
392
+ self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
393
+ self.pOrtho[1].setXRange(0, self.Ly)
394
+
395
+ self.p0.addItem(self.vLine, ignoreBounds=False)
396
+ self.p0.addItem(self.hLine, ignoreBounds=False)
397
+ self.p0.setYRange(0, self.Lx)
398
+ self.p0.setXRange(0, self.Ly)
399
+
400
+ self.win.show()
401
+ self.show()
402
+
403
+ def remove_orthoviews(self):
404
+ self.win.removeItem(self.pOrtho[0])
405
+ self.win.removeItem(self.pOrtho[1])
406
+ self.p0.removeItem(self.vLine)
407
+ self.p0.removeItem(self.hLine)
408
+ self.win.show()
409
+ self.show()
410
+
411
+ def update_crosshairs(self):
412
+ self.yortho = min(self.Ly - 1, max(0, int(self.yortho)))
413
+ self.xortho = min(self.Lx - 1, max(0, int(self.xortho)))
414
+ self.vLine.setPos(self.xortho)
415
+ self.hLine.setPos(self.yortho)
416
+ self.vLineOrtho[1].setPos(self.xortho)
417
+ self.hLineOrtho[1].setPos(self.zc)
418
+ self.vLineOrtho[0].setPos(self.zc)
419
+ self.hLineOrtho[0].setPos(self.yortho)
420
+
421
+ def update_ortho(self):
422
+ if self.NZ > 1 and self.orthobtn.isChecked():
423
+ dzcurrent = self.dz
424
+ self.dz = min(100, max(3, int(self.dzedit.text())))
425
+ self.zaspect = max(0.01, min(100., float(self.zaspectedit.text())))
426
+ self.dzedit.setText(str(self.dz))
427
+ self.zaspectedit.setText(str(self.zaspect))
428
+ if self.dz != dzcurrent:
429
+ self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
430
+ self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
431
+ dztot = min(self.NZ, self.dz * 2)
432
+ y = self.yortho
433
+ x = self.xortho
434
+ z = self.currentZ
435
+ if dztot == self.NZ:
436
+ zmin, zmax = 0, self.NZ
437
+ else:
438
+ if z - self.dz < 0:
439
+ zmin = 0
440
+ zmax = zmin + self.dz * 2
441
+ elif z + self.dz >= self.NZ:
442
+ zmax = self.NZ
443
+ zmin = zmax - self.dz * 2
444
+ else:
445
+ zmin, zmax = z - self.dz, z + self.dz
446
+ self.zc = z - zmin
447
+ self.update_crosshairs()
448
+ if self.view == 0 or self.view == 4:
449
+ for j in range(2):
450
+ if j == 0:
451
+ if self.view == 0:
452
+ image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy()
453
+ else:
454
+ image = self.stack_filtered[zmin:zmax, :,
455
+ x].transpose(1, 0, 2).copy()
456
+ else:
457
+ image = self.stack[
458
+ zmin:zmax,
459
+ y, :].copy() if self.view == 0 else self.stack_filtered[zmin:zmax,
460
+ y, :].copy()
461
+ if self.nchan == 1:
462
+ # show single channel
463
+ image = image[..., 0]
464
+ if self.color == 0:
465
+ self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
466
+ if self.nchan > 1:
467
+ levels = np.array([
468
+ self.saturation[0][self.currentZ],
469
+ self.saturation[1][self.currentZ],
470
+ self.saturation[2][self.currentZ]
471
+ ])
472
+ self.imgOrtho[j].setLevels(levels)
473
+ else:
474
+ self.imgOrtho[j].setLevels(
475
+ self.saturation[0][self.currentZ])
476
+ elif self.color > 0 and self.color < 4:
477
+ if self.nchan > 1:
478
+ image = image[..., self.color - 1]
479
+ self.imgOrtho[j].setImage(image, autoLevels=False,
480
+ lut=self.cmap[self.color])
481
+ if self.nchan > 1:
482
+ self.imgOrtho[j].setLevels(
483
+ self.saturation[self.color - 1][self.currentZ])
484
+ else:
485
+ self.imgOrtho[j].setLevels(
486
+ self.saturation[0][self.currentZ])
487
+ elif self.color == 4:
488
+ if image.ndim > 2:
489
+ image = image.astype("float32").mean(axis=2).astype("uint8")
490
+ self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
491
+ self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
492
+ elif self.color == 5:
493
+ if image.ndim > 2:
494
+ image = image.astype("float32").mean(axis=2).astype("uint8")
495
+ self.imgOrtho[j].setImage(image, autoLevels=False,
496
+ lut=self.cmap[0])
497
+ self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
498
+ self.pOrtho[0].setAspectLocked(lock=True, ratio=self.zaspect)
499
+ self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect)
500
+
501
+ else:
502
+ image = np.zeros((10, 10), "uint8")
503
+ self.imgOrtho[0].setImage(image, autoLevels=False, lut=None)
504
+ self.imgOrtho[0].setLevels([0.0, 255.0])
505
+ self.imgOrtho[1].setImage(image, autoLevels=False, lut=None)
506
+ self.imgOrtho[1].setLevels([0.0, 255.0])
507
+
508
+ zrange = zmax - zmin
509
+ self.layer_ortho = [
510
+ np.zeros((self.Ly, zrange, 4), "uint8"),
511
+ np.zeros((zrange, self.Lx, 4), "uint8")
512
+ ]
513
+ if self.masksOn:
514
+ for j in range(2):
515
+ if j == 0:
516
+ cp = self.cellpix[zmin:zmax, :, x].T
517
+ else:
518
+ cp = self.cellpix[zmin:zmax, y]
519
+ self.layer_ortho[j][..., :3] = self.cellcolors[cp, :]
520
+ self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8")
521
+ if self.selected > 0:
522
+ self.layer_ortho[j][cp == self.selected] = np.array(
523
+ [255, 255, 255, self.opacity])
524
+
525
+ if self.outlinesOn:
526
+ for j in range(2):
527
+ if j == 0:
528
+ op = self.outpix[zmin:zmax, :, x].T
529
+ else:
530
+ op = self.outpix[zmin:zmax, y]
531
+ self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8")
532
+
533
+ for j in range(2):
534
+ self.layerOrtho[j].setImage(self.layer_ortho[j])
535
+ self.win.show()
536
+ self.show()
537
+
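update_ortho builds the two orthogonal views by slicing the (Z, Y, X, C) stack at the crosshair position. The slicing in isolation, with hypothetical shapes:

    import numpy as np

    stack = np.zeros((50, 256, 320, 3), "float32")  # (NZ, Ly, Lx, nchan)
    z, y, x, dz = 25, 128, 160, 10
    zmin, zmax = z - dz, z + dz

    # YZ view at fixed x: y goes on the vertical axis -> (Ly, 2*dz, nchan)
    yz = stack[zmin:zmax, :, x].transpose(1, 0, 2)

    # XZ view at fixed y -> (2*dz, Lx, nchan)
    xz = stack[zmin:zmax, y, :]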
538
+ def toggle_ortho(self):
539
+ if self.orthobtn.isChecked():
540
+ self.add_orthoviews()
541
+ else:
542
+ self.remove_orthoviews()
543
+
544
+ def plot_clicked(self, event):
545
+ if event.button()==QtCore.Qt.LeftButton \
546
+ and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\
547
+ and not self.removing_region:
548
+ if event.double():
549
+ try:
550
+ self.p0.setYRange(0, self.Ly + self.pr)
551
+ except:
552
+ self.p0.setYRange(0, self.Ly)
553
+ self.p0.setXRange(0, self.Lx)
554
+ elif self.loaded and not self.in_stroke:
555
+ if self.orthobtn.isChecked():
556
+ items = self.win.scene().items(event.scenePos())
557
+ for x in items:
558
+ if x == self.p0:
559
+ pos = self.p0.mapSceneToView(event.scenePos())
560
+ x = int(pos.x())
561
+ y = int(pos.y())
562
+ if y >= 0 and y < self.Ly and x >= 0 and x < self.Lx:
563
+ self.yortho = y
564
+ self.xortho = x
565
+ self.update_ortho()
566
+
567
+ def update_plot(self):
568
+ super().update_plot()
569
+ if self.NZ > 1 and self.orthobtn.isChecked():
570
+ self.update_ortho()
571
+ self.win.show()
572
+ self.show()
573
+
574
+ def keyPressEvent(self, event):
575
+ if self.loaded:
576
+ if not (event.modifiers() &
577
+ (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
578
+ QtCore.Qt.AltModifier) or self.in_stroke):
579
+ updated = False
580
+ if len(self.current_point_set) > 0:
581
+ if event.key() == QtCore.Qt.Key_Return:
582
+ self.add_set()
583
+ if self.NZ > 1:
584
+ if event.key() == QtCore.Qt.Key_Left:
585
+ self.currentZ = max(0, self.currentZ - 1)
586
+ self.scroll.setValue(self.currentZ)
587
+ updated = True
588
+ elif event.key() == QtCore.Qt.Key_Right:
589
+ self.currentZ = min(self.NZ - 1, self.currentZ + 1)
590
+ self.scroll.setValue(self.currentZ)
591
+ updated = True
592
+ else:
593
+ nviews = self.ViewDropDown.count() - 1
594
+ nviews += int(
595
+ self.ViewDropDown.model().item(self.ViewDropDown.count() -
596
+ 1).isEnabled())
597
+ if event.key() == QtCore.Qt.Key_X:
598
+ self.MCheckBox.toggle()
599
+ if event.key() == QtCore.Qt.Key_Z:
600
+ self.OCheckBox.toggle()
601
+ if event.key() == QtCore.Qt.Key_Left or event.key(
602
+ ) == QtCore.Qt.Key_A:
603
+ self.currentZ = max(0, self.currentZ - 1)
604
+ self.scroll.setValue(self.currentZ)
605
+ updated = True
606
+ elif event.key() == QtCore.Qt.Key_Right or event.key(
607
+ ) == QtCore.Qt.Key_D:
608
+ self.currentZ = min(self.NZ - 1, self.currentZ + 1)
609
+ self.scroll.setValue(self.currentZ)
610
+ updated = True
611
+ elif event.key() == QtCore.Qt.Key_PageDown:
612
+ self.view = (self.view + 1) % (nviews)
613
+ self.ViewDropDown.setCurrentIndex(self.view)
614
+ elif event.key() == QtCore.Qt.Key_PageUp:
615
+ self.view = (self.view - 1) % (nviews)
616
+ self.ViewDropDown.setCurrentIndex(self.view)
617
+
618
+ # can change background or stroke size if cell not finished
619
+ if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
620
+ self.color = (self.color - 1) % (6)
621
+ self.RGBDropDown.setCurrentIndex(self.color)
622
+ elif event.key() == QtCore.Qt.Key_Down or event.key(
623
+ ) == QtCore.Qt.Key_S:
624
+ self.color = (self.color + 1) % (6)
625
+ self.RGBDropDown.setCurrentIndex(self.color)
626
+ elif event.key() == QtCore.Qt.Key_R:
627
+ if self.color != 1:
628
+ self.color = 1
629
+ else:
630
+ self.color = 0
631
+ self.RGBDropDown.setCurrentIndex(self.color)
632
+ elif event.key() == QtCore.Qt.Key_G:
633
+ if self.color != 2:
634
+ self.color = 2
635
+ else:
636
+ self.color = 0
637
+ self.RGBDropDown.setCurrentIndex(self.color)
638
+ elif event.key() == QtCore.Qt.Key_B:
639
+ if self.color != 3:
640
+ self.color = 3
641
+ else:
642
+ self.color = 0
643
+ self.RGBDropDown.setCurrentIndex(self.color)
644
+ elif (event.key() == QtCore.Qt.Key_Comma or
645
+ event.key() == QtCore.Qt.Key_Period):
646
+ count = self.BrushChoose.count()
647
+ gci = self.BrushChoose.currentIndex()
648
+ if event.key() == QtCore.Qt.Key_Comma:
649
+ gci = max(0, gci - 1)
650
+ else:
651
+ gci = min(count - 1, gci + 1)
652
+ self.BrushChoose.setCurrentIndex(gci)
653
+ self.brush_choose()
654
+ if not updated:
655
+ self.update_plot()
656
+ if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
657
+ self.p0.keyPressEvent(event)
658
+
659
+ def update_ztext(self):
660
+ zpos = self.currentZ
661
+ try:
662
+ zpos = int(self.zpos.text())
663
+ except ValueError:
664
+ print("ERROR: zposition is not a number")
665
+ self.currentZ = max(0, min(self.NZ - 1, zpos))
666
+ self.zpos.setText(str(self.currentZ))
667
+ self.scroll.setValue(self.currentZ)
models/seg_post_model/cellpose/gui/guihelpwindowtext.html ADDED
@@ -0,0 +1,143 @@
1
+ <qt>
2
+ <p class="has-line-data" data-line-start="5" data-line-end="6">
3
+ <b>Main GUI mouse controls:</b>
4
+ </p>
5
+ <ul>
6
+ <li class="has-line-data" data-line-start="7" data-line-end="8">Pan = left-click + drag</li>
7
+ <li class="has-line-data" data-line-start="8" data-line-end="9">Zoom = scroll wheel (or +/= and - buttons)</li>
8
+ <li class="has-line-data" data-line-start="9" data-line-end="10">Full view = double left-click</li>
9
+ <li class="has-line-data" data-line-start="10" data-line-end="11">Select mask = left-click on mask</li>
10
+ <li class="has-line-data" data-line-start="11" data-line-end="12">Delete mask = Ctrl (or COMMAND on Mac) +
11
+ left-click
12
+ </li>
13
+ <li class="has-line-data" data-line-start="11" data-line-end="12">Merge masks = Alt + left-click (will merge
14
+ last two)
15
+ </li>
16
+ <li class="has-line-data" data-line-start="12" data-line-end="13">Start draw mask = right-click</li>
17
+ <li class="has-line-data" data-line-start="13" data-line-end="15">End draw mask = right-click, or return to
18
+ circle at beginning
19
+ </li>
20
+ </ul>
21
+ <p class="has-line-data" data-line-start="15" data-line-end="16">Overlaps in masks are NOT allowed. If you
22
+ draw a mask on top of another mask, it is cropped so that it doesn’t overlap with the old mask. Masks in 2D
23
+ should be single strokes (single stroke is checked). If you want to draw masks in 3D (experimental), then
24
+ you can turn this option off and draw a stroke on each plane with the cell and then press ENTER. 3D
25
+ labelling will fill in planes that you have not labelled, so you do not have to label every plane densely.
26
+ </p>
27
+ <p class="has-line-data" data-line-start="17" data-line-end="18"> <b>!NOTE!:</b> The GUI automatically saves after
28
+ you draw a mask in 2D but NOT after 3D mask drawing and NOT after segmentation. Save in the file menu or
29
+ with Ctrl+S. The output file is in the same folder as the loaded image with <code>_seg.npy</code> appended.
30
+ </p>
31
+
32
+ <p class="has-line-data" data-line-start="19" data-line-end="20"> <b>Bulk Mask Deletion</b>
33
+ Clicking the 'delete multiple' button will allow you to select and delete multiple masks at once.
34
+ Masks can be deselected by clicking on them again. Once you have selected all the masks you want to delete,
35
+ click the 'done' button to delete them.
36
+ <br>
37
+ <br>
38
+ Alternatively, you can create a rectangular region to delete a region of masks by clicking the
39
+ 'delete multiple' button, and then moving and/or resizing the region to select the masks you want to delete.
40
+ Once you have selected the masks you want to delete, click the 'done' button to delete them.
41
+ <br>
42
+ <br>
43
+ At any point in the process, you can click the 'cancel' button to cancel the bulk deletion.
44
+ </p>
45
+ <hr>
46
+ <p>FYI there are tooltips throughout the GUI (hover over text to see)</p>
47
+ <table class="table table-striped table-bordered">
51
+ <thead>
52
+ <tr>
53
+ <th>Keyboard shortcuts</th>
54
+ <th>Description</th>
55
+ </tr>
56
+ </thead>
57
+ <tbody>
58
+ <tr>
59
+ <td>=/+ button // - button</td>
60
+ <td>zoom in // zoom out</td>
61
+ </tr>
62
+ <tr>
63
+ <td>CTRL+Z</td>
64
+ <td>undo previously drawn mask/stroke</td>
65
+ </tr>
66
+ <tr>
67
+ <td>CTRL+Y</td>
68
+ <td>undo remove mask</td>
69
+ </tr>
70
+ <tr>
71
+ <td>CTRL+0</td>
72
+ <td>clear all masks</td>
73
+ </tr>
74
+ <tr>
75
+ <td>CTRL+L</td>
76
+ <td>load image (can alternatively drag and drop image)</td>
77
+ </tr>
78
+ <tr>
79
+ <td>CTRL+S</td>
80
+ <td>SAVE MASKS IN IMAGE to <code>_seg.npy</code> file</td>
81
+ </tr>
82
+ <tr>
83
+ <td>CTRL+T</td>
84
+ <td>train model using _seg.npy files in folder</td>
85
+ </tr>
86
+ <tr>
87
+ <td>CTRL+P</td>
88
+ <td>load <code>_seg.npy</code> file (note: it will load automatically with image if it exists)</td>
89
+ </tr>
90
+ <tr>
91
+ <td>CTRL+M</td>
92
+ <td>load masks file (must be same size as image with 0 for NO mask, and 1,2,3… for masks)</td>
93
+ </tr>
94
+ <tr>
95
+ <td>CTRL+N</td>
96
+ <td>save masks as PNG</td>
97
+ </tr>
98
+ <tr>
99
+ <td>CTRL+R</td>
100
+ <td>save ROIs to native ImageJ ROI format</td>
101
+ </tr>
102
+ <tr>
103
+ <td>CTRL+F</td>
104
+ <td>save flows to image file</td>
105
+ </tr>
106
+ <tr>
107
+ <td>A/D or LEFT/RIGHT</td>
108
+ <td>cycle through images in current directory</td>
109
+ </tr>
110
+ <tr>
111
+ <td>W/S or UP/DOWN</td>
112
+ <td>change color (RGB/gray/red/green/blue)</td>
113
+ </tr>
114
+ <tr>
115
+ <td>R / G / B</td>
116
+ <td>toggle between RGB and Red or Green or Blue</td>
117
+ </tr>
118
+ <tr>
119
+ <td>PAGE-UP / PAGE-DOWN</td>
120
+ <td>change to flows and cell prob views (if segmentation computed)</td>
121
+ </tr>
122
+ <tr>
123
+ <td>X</td>
124
+ <td>turn masks ON or OFF</td>
125
+ </tr>
126
+ <tr>
127
+ <td>Z</td>
128
+ <td>toggle outlines ON or OFF</td>
129
+ </tr>
130
+ <tr>
131
+ <td>, / .</td>
132
+ <td>increase / decrease brush size for drawing masks</td>
133
+ </tr>
134
+ </tbody>
135
+ </table>
136
+ <p class="has-line-data" data-line-start="36" data-line-end="37"><strong>Segmentation options
137
+ (2D only) </strong></p>
138
+ <p class="has-line-data" data-line-start="38" data-line-end="39">use GPU: if you have specially
139
+ installed the cuda version of torch, then you can activate this. Due to the size of the
140
+ transformer network, it will greatly speed up the processing time.</p>
141
+ <p class="has-line-data" data-line-start="40" data-line-end="41">There are no channel options
142
+ in v4.0.1+ since all 3 channels are used for segmentation. </p>
143
+ </qt>
models/seg_post_model/cellpose/gui/guiparts.py ADDED
@@ -0,0 +1,793 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+ from qtpy import QtGui, QtCore
5
+ from qtpy.QtGui import QPixmap, QDoubleValidator
6
+ from qtpy.QtWidgets import QWidget, QDialog, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox, QVBoxLayout
7
+ import pyqtgraph as pg
8
+ import numpy as np
9
+ import pathlib, os
10
+
11
+
12
+ def stylesheet():
13
+ return """
14
+ QToolTip {
15
+ background-color: black;
16
+ color: white;
17
+ border: black solid 1px
18
+ }
19
+ QComboBox {color: white;
20
+ background-color: rgb(40,40,40);}
21
+ QComboBox::item:enabled { color: white;
22
+ background-color: rgb(40,40,40);
23
+ selection-color: white;
24
+ selection-background-color: rgb(50,100,50);}
25
+ QComboBox::item:!enabled {
26
+ background-color: rgb(40,40,40);
27
+ color: rgb(100,100,100);
28
+ }
29
+ QScrollArea > QWidget > QWidget
30
+ {
31
+ background: transparent;
32
+ border: none;
33
+ margin: 0px 0px 0px 0px;
34
+ }
35
+
36
+ QGroupBox
37
+ { border: 1px solid white; color: rgb(255,255,255);
38
+ border-radius: 6px;
39
+ margin-top: 8px;
40
+ padding: 0px 0px;}
41
+
42
+ QPushButton:pressed {Text-align: center;
43
+ background-color: rgb(150,50,150);
44
+ border-color: white;
45
+ color:white;}
46
+ QToolTip {
47
+ background-color: black;
48
+ color: white;
49
+ border: black solid 1px
50
+ }
51
+ QPushButton:!pressed {Text-align: center;
52
+ background-color: rgb(50,50,50);
53
+ border-color: white;
54
+ color:white;}
55
+ QToolTip {
56
+ background-color: black;
57
+ color: white;
58
+ border: black solid 1px
59
+ }
60
+ QPushButton:disabled {Text-align: center;
61
+ background-color: rgb(30,30,30);
62
+ border-color: white;
63
+ color:rgb(80,80,80);}
64
+ QToolTip {
65
+ background-color: black;
66
+ color: white;
67
+ border: black solid 1px
68
+ }
69
+
70
+ """
71
+
72
+
73
+ class DarkPalette(QtGui.QPalette):
74
+ """Class that inherits from pyqtgraph.QtGui.QPalette and renders dark colours for the application.
75
+ (from pykilosort/kilosort4)
76
+ """
77
+
78
+ def __init__(self):
79
+ QtGui.QPalette.__init__(self)
80
+ self.setup()
81
+
82
+ def setup(self):
83
+ self.setColor(QtGui.QPalette.Window, QtGui.QColor(40, 40, 40))
84
+ self.setColor(QtGui.QPalette.WindowText, QtGui.QColor(255, 255, 255))
85
+ self.setColor(QtGui.QPalette.Base, QtGui.QColor(34, 27, 24))
86
+ self.setColor(QtGui.QPalette.AlternateBase, QtGui.QColor(53, 50, 47))
87
+ self.setColor(QtGui.QPalette.ToolTipBase, QtGui.QColor(255, 255, 255))
88
+ self.setColor(QtGui.QPalette.ToolTipText, QtGui.QColor(255, 255, 255))
89
+ self.setColor(QtGui.QPalette.Text, QtGui.QColor(255, 255, 255))
90
+ self.setColor(QtGui.QPalette.Button, QtGui.QColor(53, 50, 47))
91
+ self.setColor(QtGui.QPalette.ButtonText, QtGui.QColor(255, 255, 255))
92
+ self.setColor(QtGui.QPalette.BrightText, QtGui.QColor(255, 0, 0))
93
+ self.setColor(QtGui.QPalette.Link, QtGui.QColor(42, 130, 218))
94
+ self.setColor(QtGui.QPalette.Highlight, QtGui.QColor(42, 130, 218))
95
+ self.setColor(QtGui.QPalette.HighlightedText, QtGui.QColor(0, 0, 0))
96
+ self.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.Text,
97
+ QtGui.QColor(128, 128, 128))
98
+ self.setColor(
99
+ QtGui.QPalette.Disabled,
100
+ QtGui.QPalette.ButtonText,
101
+ QtGui.QColor(128, 128, 128),
102
+ )
103
+ self.setColor(
104
+ QtGui.QPalette.Disabled,
105
+ QtGui.QPalette.WindowText,
106
+ QtGui.QColor(128, 128, 128),
107
+ )
108
+
109
+
110
+ # def create_channel_choose():
111
+ # # choose channel
112
+ # ChannelChoose = [QComboBox(), QComboBox()]
113
+ # ChannelLabels = []
114
+ # ChannelChoose[0].addItems(["gray", "red", "green", "blue"])
115
+ # ChannelChoose[1].addItems(["none", "red", "green", "blue"])
116
+ # cstr = ["chan to segment:", "chan2 (optional): "]
117
+ # for i in range(2):
118
+ # ChannelLabels.append(QLabel(cstr[i]))
119
+ # if i == 0:
120
+ # ChannelLabels[i].setToolTip(
121
+ # "this is the channel in which the cytoplasm or nuclei exist \
122
+ # that you want to segment")
123
+ # ChannelChoose[i].setToolTip(
124
+ # "this is the channel in which the cytoplasm or nuclei exist \
125
+ # that you want to segment")
126
+ # else:
127
+ # ChannelLabels[i].setToolTip(
128
+ # "if <em>cytoplasm</em> model is chosen, and you also have a \
129
+ # nuclear channel, then choose the nuclear channel for this option")
130
+ # ChannelChoose[i].setToolTip(
131
+ # "if <em>cytoplasm</em> model is chosen, and you also have a \
132
+ # nuclear channel, then choose the nuclear channel for this option")
133
+
134
+ # return ChannelChoose, ChannelLabels
135
+
136
+
137
+ class ModelButton(QPushButton):
138
+
139
+ def __init__(self, parent, model_name, text):
140
+ super().__init__()
141
+ self.setEnabled(False)
142
+ self.setText(text)
143
+ self.setFont(parent.boldfont)
144
+ self.clicked.connect(lambda: self.press(parent))
145
+ self.model_name = "cpsam"
146
+
147
+ def press(self, parent):
148
+ parent.compute_segmentation(model_name="cpsam")
149
+
150
+
151
+ class FilterButton(QPushButton):
152
+
153
+ def __init__(self, parent, text):
154
+ super().__init__()
155
+ self.setEnabled(False)
156
+ self.model_type = text
157
+ self.setText(text)
158
+ self.setFont(parent.medfont)
159
+ self.clicked.connect(lambda: self.press(parent))
160
+
161
+ def press(self, parent):
162
+ if self.model_type == "filter":
163
+ parent.restore = "filter"
164
+ normalize_params = parent.get_normalize_params()
165
+ if (normalize_params["sharpen_radius"] == 0 and
166
+ normalize_params["smooth_radius"] == 0 and
167
+ normalize_params["tile_norm_blocksize"] == 0):
168
+ print(
169
+ "GUI_ERROR: no filtering settings on (use custom filter settings)")
170
+ parent.restore = None
171
+ return
172
+ parent.restore = self.model_type
173
+ parent.compute_saturation()
174
+ # elif self.model_type != "none":
175
+ # parent.compute_denoise_model(model_type=self.model_type)
176
+ else:
177
+ parent.clear_restore()
178
+ # parent.set_restore_button()
179
+
180
+
181
+ class ObservableVariable(QtCore.QObject):
182
+ valueChanged = QtCore.Signal(object)
183
+
184
+ def __init__(self, initial=None):
185
+ super().__init__()
186
+ self._value = initial
187
+
188
+ def set(self, new_value):
189
+ """ Use this method to get emit the value changing and update the ROI count"""
190
+ if new_value != self._value:
191
+ self._value = new_value
192
+ self.valueChanged.emit(new_value)
193
+
194
+ def get(self):
195
+ return self._value
196
+
197
+ def __call__(self):
198
+ return self._value
199
+
200
+ def reset(self):
201
+ self.set(0)
202
+
203
+ def __iadd__(self, amount):
204
+ if not isinstance(amount, (int, float)):
205
+ raise TypeError("Value must be numeric.")
206
+ self.set(self._value + amount)
207
+ return self
208
+
209
+ def __radd__(self, other):
210
+ return other + self._value
211
+
212
+ def __add__(self, other):
213
+ return other + self._value
214
+
215
+ def __isub__(self, amount):
216
+ if not isinstance(amount, (int, float)):
217
+ raise TypeError("Value must be numeric.")
218
+ self.set(self._value - amount)
219
+ return self
220
+
221
+ def __str__(self):
222
+ return str(self._value)
223
+
224
+ def __lt__(self, x):
225
+ return self._value < x
226
+
227
+ def __gt__(self, x):
228
+ return self._value > x
229
+
230
+ def __eq__(self, x):
231
+ return self._value == x
232
+
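A short usage sketch of ObservableVariable (hypothetical consumer; a directly-connected Qt signal is delivered synchronously):

    roi_count = ObservableVariable(0)
    roi_count.valueChanged.connect(lambda n: print(f"ROI count is now {n}"))
    roi_count += 1      # __iadd__ calls set(), which emits valueChanged(1)
    roi_count.set(5)    # emits valueChanged(5)
    roi_count.set(5)    # value unchanged, so nothing is emitted
    print(roi_count())  # 5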
233
+
234
+ class NormalizationSettings(QWidget):
235
+ # TODO
236
+ pass
237
+
238
+
239
+ class SegmentationSettings(QWidget):
240
+ """ Container for gui settings. Validation is done automatically so any attributes can
241
+ be accessed without concern.
242
+ """
243
+ def __init__(self, font):
244
+ super().__init__()
245
+
246
+ # Put everything in a grid layout:
247
+ grid_layout = QGridLayout()
248
+ widget_container = QWidget()
249
+ widget_container.setLayout(grid_layout)
250
+ row = 0
251
+
252
+ ########################### Diameter ###########################
253
+ # TODO: Validate inputs
254
+ diam_qlabel = QLabel("diameter:")
255
+ diam_qlabel.setToolTip("diameter of cells in pixels. If not 30, image will be resized to this")
256
+ diam_qlabel.setFont(font)
257
+ grid_layout.addWidget(diam_qlabel, row, 0, 1, 2)
258
+ self.diameter_box = QLineEdit()
259
+ self.diameter_box.setToolTip("diameter of cells in pixels. If not blank, image will be resized relative to 30 pixel cell diameters")
260
+ self.diameter_box.setFont(font)
261
+ self.diameter_box.setFixedWidth(40)
262
+ self.diameter_box.setText(' ')
263
+ grid_layout.addWidget(self.diameter_box, row, 2, 1, 2)
264
+
265
+ row += 1
266
+
267
+ ########################### Flow threshold ###########################
268
+ # TODO: Validate inputs
269
+ flow_threshold_qlabel = QLabel("flow\nthreshold:")
270
+ flow_threshold_qlabel.setToolTip("threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run")
271
+ flow_threshold_qlabel.setFont(font)
272
+ grid_layout.addWidget(flow_threshold_qlabel, row, 0, 1, 2)
273
+ self.flow_threshold_box = QLineEdit()
274
+ self.flow_threshold_box.setText("0.4")
275
+ self.flow_threshold_box.setFixedWidth(40)
276
+ self.flow_threshold_box.setFont(font)
277
+ grid_layout.addWidget(self.flow_threshold_box, row, 2, 1, 2)
278
+ self.flow_threshold_box.setToolTip("threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run")
279
+
280
+ ########################### Cellprob threshold ###########################
281
+ # TODO: Validate inputs
282
+ cellprob_qlabel = QLabel("cellprob\nthreshold:")
283
+ cellprob_qlabel.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run")
284
+ cellprob_qlabel.setFont(font)
285
+ grid_layout.addWidget(cellprob_qlabel, row, 4, 1, 2)
286
+ self.cellprob_threshold_box = QLineEdit()
287
+ self.cellprob_threshold_box.setText("0.0")
288
+ self.cellprob_threshold_box.setFixedWidth(40)
289
+ self.cellprob_threshold_box.setFont(font)
290
+ self.cellprob_threshold_box.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run")
291
+ grid_layout.addWidget(self.cellprob_threshold_box, row, 6, 1, 2)
292
+
293
+ row += 1
294
+
295
+ ########################### Norm percentiles ###########################
296
+ norm_percentiles_qlabel = QLabel("norm percentiles:")
297
+ norm_percentiles_qlabel.setToolTip("sets normalization percentiles for segmentation and denoising\n(pixels at lower percentile set to 0.0 and at upper set to 1.0 for network)")
298
+ norm_percentiles_qlabel.setFont(font)
299
+ grid_layout.addWidget(norm_percentiles_qlabel, row, 0, 1, 8)
300
+
301
+ row += 1
302
+ validator = QDoubleValidator(0.0, 100.0, 2)
303
+ validator.setNotation(QDoubleValidator.StandardNotation)
304
+
305
+ low_norm_qlabel = QLabel('lower:')
306
+ low_norm_qlabel.setToolTip("pixels at this percentile set to 0 (default 1.0)")
307
+ low_norm_qlabel.setFont(font)
308
+ grid_layout.addWidget(low_norm_qlabel, row, 0, 1, 2)
309
+ self.norm_percentile_low_box = QLineEdit()
310
+ self.norm_percentile_low_box.setText("1.0")
311
+ self.norm_percentile_low_box.setFont(font)
312
+ self.norm_percentile_low_box.setFixedWidth(40)
313
+ self.norm_percentile_low_box.setToolTip("pixels at this percentile set to 0 (default 1.0)")
314
+ self.norm_percentile_low_box.setValidator(validator)
315
+ self.norm_percentile_low_box.editingFinished.connect(self.validate_normalization_range)
316
+ grid_layout.addWidget(self.norm_percentile_low_box, row, 2, 1, 1)
317
+
318
+ high_norm_qlabel = QLabel('upper:')
319
+ high_norm_qlabel.setToolTip("pixels at this percentile set to 1 (default 99.0)")
320
+ high_norm_qlabel.setFont(font)
321
+ grid_layout.addWidget(high_norm_qlabel, row, 4, 1, 2)
322
+ self.norm_percentile_high_box = QLineEdit()
323
+ self.norm_percentile_high_box.setText("99.0")
324
+ self.norm_percentile_high_box.setFont(font)
325
+ self.norm_percentile_high_box.setFixedWidth(40)
326
+ self.norm_percentile_high_box.setToolTip("pixels at this percentile set to 1 (default 99.0)")
327
+ self.norm_percentile_high_box.setValidator(validator)
328
+ self.norm_percentile_high_box.editingFinished.connect(self.validate_normalization_range)
329
+ grid_layout.addWidget(self.norm_percentile_high_box, row, 6, 1, 2)
330
+
331
+ row += 1
332
+
333
+ ########################### niter ###########################
334
+ # TODO: change this to follow the same default logic as 'diameter' above
335
+ # TODO: input validation
336
+ niter_qlabel = QLabel("niter dynamics:")
337
+ niter_qlabel.setFont(font)
338
+ niter_qlabel.setToolTip("number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria")
339
+ grid_layout.addWidget(niter_qlabel, row, 0, 1, 4)
340
+ self.niter_box = QLineEdit()
341
+ self.niter_box.setText("0")
342
+ self.niter_box.setFixedWidth(40)
343
+ self.niter_box.setFont(font)
344
+ self.niter_box.setToolTip("number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria")
345
+ grid_layout.addWidget(self.niter_box, row, 4, 1, 2)
346
+
347
+ self.setLayout(grid_layout)
348
+
349
+ def validate_normalization_range(self):
350
+ low_text = self.norm_percentile_low_box.text()
351
+ high_text = self.norm_percentile_high_box.text()
352
+
353
+ if not low_text or low_text.isspace():
354
+ self.norm_percentile_low_box.setText('1.0')
355
+ low_text = '1.0'
356
+ if not high_text or high_text.isspace():
357
+ self.norm_percentile_high_box.setText('99.0')
358
+ high_text = '99.0'
359
+
360
+ low = float(low_text)
361
+ high = float(high_text)
362
+
363
+ if low >= high:
364
+ # Invalid: show error and mark fields
365
+ self.norm_percentile_low_box.setStyleSheet("border: 1px solid red;")
366
+ self.norm_percentile_high_box.setStyleSheet("border: 1px solid red;")
367
+ else:
368
+ # Valid: clear style
369
+ self.norm_percentile_low_box.setStyleSheet("")
370
+ self.norm_percentile_high_box.setStyleSheet("")
371
+
372
+ @property
373
+ def low_percentile(self):
374
+ """ Also validate the low input by returning 1.0 if text doesn't work """
375
+ low_text = self.norm_percentile_low_box.text()
376
+ if not low_text or low_text.isspace():
377
+ self.norm_percentile_low_box.setText('1.0')
378
+ low_text = '1.0'
379
+ return float(self.norm_percentile_low_box.text())
380
+
381
+ @property
382
+ def high_percentile(self):
383
+ """ Also validate the high input by returning 99.0 if text doesn't work """
384
+ high_text = self.norm_percentile_high_box.text()
385
+ if not high_text or high_text.isspace():
386
+ self.norm_percentile_high_box.setText('99.0')
387
+ high_text = '99.0'
388
+ return float(self.norm_percentile_high_box.text())
389
+
390
+ @property
391
+ def diameter(self):
392
+ """ Get the diameter from the diameter box, if box isn't a number return None"""
393
+ try:
394
+ d = float(self.diameter_box.text())
395
+ except ValueError:
396
+ d = None
397
+ return d
398
+
399
+ @property
400
+ def flow_threshold(self):
401
+ return float(self.flow_threshold_box.text())
402
+
403
+ @property
404
+ def cellprob_threshold(self):
405
+ return float(self.cellprob_threshold_box.text())
406
+
407
+ @property
408
+ def niter(self):
409
+ try:
+ num = int(self.niter_box.text())
+ except ValueError:
+ num = 0
410
+ if num < 1:
411
+ self.niter_box.setText('200')
412
+ return 200
413
+ else:
414
+ return num
415
+
416
+
417
+
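The pattern used throughout SegmentationSettings (a QLineEdit whose property falls back to a default when the text does not parse) can be distilled to a minimal sketch. FloatEdit here is hypothetical, and a QApplication must exist before any QWidget is instantiated:

    from qtpy.QtWidgets import QLineEdit

    class FloatEdit(QLineEdit):
        """QLineEdit that always yields a float, restoring a default on bad input."""

        def __init__(self, default):
            super().__init__(str(default))
            self._default = default

        @property
        def value(self):
            try:
                return float(self.text())
            except ValueError:
                self.setText(str(self._default))  # restore a valid display
                return self._default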
418
+ class TrainWindow(QDialog):
419
+
420
+ def __init__(self, parent, model_strings):
421
+ super().__init__(parent)
422
+ self.setGeometry(100, 100, 900, 550)
423
+ self.setWindowTitle("train settings")
424
+ self.win = QWidget(self)
425
+ self.l0 = QGridLayout()
426
+ self.win.setLayout(self.l0)
427
+
428
+ yoff = 0
429
+ qlabel = QLabel("train model w/ images + _seg.npy in current folder >>")
430
+ qlabel.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))
431
+
432
+ qlabel.setAlignment(QtCore.Qt.AlignVCenter)
433
+ self.l0.addWidget(qlabel, yoff, 0, 1, 2)
434
+
435
+ # choose initial model
436
+ yoff += 1
437
+ self.ModelChoose = QComboBox()
438
+ self.ModelChoose.addItems(model_strings)
439
+ self.ModelChoose.setFixedWidth(150)
440
+ self.ModelChoose.setCurrentIndex(parent.training_params["model_index"])
441
+ self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1)
442
+ qlabel = QLabel("initial model: ")
443
+ qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
444
+ self.l0.addWidget(qlabel, yoff, 0, 1, 1)
445
+
446
+ # choose parameters
447
+ labels = ["learning_rate", "weight_decay", "n_epochs", "model_name"]
448
+ self.edits = []
449
+ yoff += 1
450
+ for i, label in enumerate(labels):
451
+ qlabel = QLabel(label)
452
+ qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
453
+ self.l0.addWidget(qlabel, i + yoff, 0, 1, 1)
454
+ self.edits.append(QLineEdit())
455
+ self.edits[-1].setText(str(parent.training_params[label]))
456
+ self.edits[-1].setFixedWidth(200)
457
+ self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1)
458
+
459
+ yoff += len(labels)
460
+
461
+ yoff += 1
462
+ self.use_norm = QCheckBox("use restored/filtered image")
463
+ self.use_norm.setChecked(True)
464
+
465
+ yoff += 2
466
+ qlabel = QLabel(
467
+ "(to remove files, click cancel then remove \nfrom folder and reopen train window)"
468
+ )
469
+ self.l0.addWidget(qlabel, yoff, 0, 2, 4)
470
+
471
+ # click button
472
+ yoff += 3
473
+ QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
474
+ self.buttonBox = QDialogButtonBox(QBtn)
475
+ self.buttonBox.accepted.connect(lambda: self.accept(parent))
476
+ self.buttonBox.rejected.connect(self.reject)
477
+ self.l0.addWidget(self.buttonBox, yoff, 0, 1, 4)
478
+
479
+ # list files in folder
480
+ qlabel = QLabel("filenames")
481
+ qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
482
+ self.l0.addWidget(qlabel, 0, 4, 1, 1)
483
+ qlabel = QLabel("# of masks")
484
+ qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
485
+ self.l0.addWidget(qlabel, 0, 5, 1, 1)
486
+
487
+ for i in range(10):
488
+ if i > len(parent.train_files) - 1:
489
+ break
490
+ elif i == 9 and len(parent.train_files) > 10:
491
+ label = "..."
492
+ nmasks = "..."
493
+ else:
494
+ label = os.path.split(parent.train_files[i])[-1]
495
+ nmasks = str(parent.train_labels[i].max())
496
+ qlabel = QLabel(label)
497
+ self.l0.addWidget(qlabel, i + 1, 4, 1, 1)
498
+ qlabel = QLabel(nmasks)
499
+ qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
500
+ self.l0.addWidget(qlabel, i + 1, 5, 1, 1)
501
+
502
+ def accept(self, parent):
503
+ # set training params
504
+ parent.training_params = {
505
+ "model_index": self.ModelChoose.currentIndex(),
506
+ "learning_rate": float(self.edits[0].text()),
507
+ "weight_decay": float(self.edits[1].text()),
508
+ "n_epochs": int(self.edits[2].text()),
509
+ "model_name": self.edits[3].text(),
510
+ #"use_norm": True if self.use_norm.isChecked() else False,
511
+ }
512
+ self.done(1)
513
+
514
+
515
+ class ExampleGUI(QDialog):
516
+
517
+ def __init__(self, parent=None):
518
+ super(ExampleGUI, self).__init__(parent)
519
+ self.setGeometry(100, 100, 1300, 900)
520
+ self.setWindowTitle("GUI layout")
521
+ self.win = QWidget(self)
522
+ layout = QGridLayout()
523
+ self.win.setLayout(layout)
524
+ guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png")
525
+ guip_path = str(guip_path.resolve())
526
+ pixmap = QPixmap(guip_path)
527
+ label = QLabel(self)
528
+ label.setPixmap(pixmap)
530
+ layout.addWidget(label, 0, 0, 1, 1)
531
+
532
+
533
+ class HelpWindow(QDialog):
534
+
535
+ def __init__(self, parent=None):
536
+ super(HelpWindow, self).__init__(parent)
537
+ self.setGeometry(100, 50, 700, 1000)
538
+ self.setWindowTitle("cellpose help")
539
+ self.win = QWidget(self)
540
+ layout = QGridLayout()
541
+ self.win.setLayout(layout)
542
+
543
+ text_file = pathlib.Path(__file__).parent.joinpath("guihelpwindowtext.html")
544
+ with open(str(text_file.resolve()), "r") as f:
545
+ text = f.read()
546
+
547
+ label = QLabel(text)
548
+ label.setFont(QtGui.QFont("Arial", 8))
549
+ label.setWordWrap(True)
550
+ layout.addWidget(label, 0, 0, 1, 1)
551
+ self.show()
552
+
553
+
554
+ class TrainHelpWindow(QDialog):
555
+
556
+ def __init__(self, parent=None):
557
+ super(TrainHelpWindow, self).__init__(parent)
558
+ self.setGeometry(100, 50, 700, 300)
559
+ self.setWindowTitle("training instructions")
560
+ self.win = QWidget(self)
561
+ layout = QGridLayout()
562
+ self.win.setLayout(layout)
563
+
564
+ text_file = pathlib.Path(__file__).parent.joinpath(
565
+ "guitrainhelpwindowtext.html")
566
+ with open(str(text_file.resolve()), "r") as f:
567
+ text = f.read()
568
+
569
+ label = QLabel(text)
570
+ label.setFont(QtGui.QFont("Arial", 8))
571
+ label.setWordWrap(True)
572
+ layout.addWidget(label, 0, 0, 1, 1)
573
+ self.show()
574
+
575
+
576
+ class ViewBoxNoRightDrag(pg.ViewBox):
577
+
578
+ def __init__(self, parent=None, border=None, lockAspect=False, enableMouse=True,
579
+ invertY=False, enableMenu=True, name=None, invertX=False):
580
+ pg.ViewBox.__init__(self, None, border, lockAspect, enableMouse, invertY,
581
+ enableMenu, name, invertX)
582
+ self.parent = parent
583
+ self.axHistoryPointer = -1
584
+
585
+ def keyPressEvent(self, ev):
586
+ """
587
+ This routine should capture key presses in the current view box.
588
+ The following events are implemented:
589
+ +/= : moves forward in the zooming stack (if it exists)
590
+ - : moves backward in the zooming stack (if it exists)
591
+
592
+ """
593
+ ev.accept()
594
+ if ev.text() == "-":
595
+ self.scaleBy([1.1, 1.1])
596
+ elif ev.text() in ["+", "="]:
597
+ self.scaleBy([0.9, 0.9])
598
+ else:
599
+ ev.ignore()
600
+
601
+
602
+ class ImageDraw(pg.ImageItem):
603
+ """
604
+ **Bases:** :class:`GraphicsObject <pyqtgraph.GraphicsObject>`
605
+ GraphicsObject displaying an image. Optimized for rapid update (ie video display).
606
+ This item displays either a 2D numpy array (height, width) or
607
+ a 3D array (height, width, RGBa). This array is optionally scaled (see
608
+ :func:`setLevels <pyqtgraph.ImageItem.setLevels>`) and/or colored
609
+ with a lookup table (see :func:`setLookupTable <pyqtgraph.ImageItem.setLookupTable>`)
610
+ before being displayed.
611
+ ImageItem is frequently used in conjunction with
612
+ :class:`HistogramLUTItem <pyqtgraph.HistogramLUTItem>` or
613
+ :class:`HistogramLUTWidget <pyqtgraph.HistogramLUTWidget>` to provide a GUI
614
+ for controlling the levels and lookup table used to display the image.
615
+ """
616
+
617
+ sigImageChanged = QtCore.Signal()
618
+
619
+ def __init__(self, image=None, viewbox=None, parent=None, **kargs):
620
+ super(ImageDraw, self).__init__()
621
+ self.levels = np.array([0, 255])
622
+ self.lut = None
623
+ self.autoDownsample = False
624
+ self.axisOrder = "row-major"
625
+ self.removable = False
626
+
627
+ self.parent = parent
628
+ self.setDrawKernel(kernel_size=self.parent.brush_size)
629
+ self.parent.current_stroke = []
630
+ self.parent.in_stroke = False
631
+
632
+ def mouseClickEvent(self, ev):
633
+ if (self.parent.masksOn or
634
+ self.parent.outlinesOn) and not self.parent.removing_region:
635
+ is_right_click = ev.button() == QtCore.Qt.RightButton
636
+ if self.parent.loaded \
637
+ and (is_right_click or ev.modifiers() & QtCore.Qt.ShiftModifier and not ev.double())\
638
+ and not self.parent.deleting_multiple:
639
+ if not self.parent.in_stroke:
640
+ ev.accept()
641
+ self.create_start(ev.pos())
642
+ self.parent.stroke_appended = False
643
+ self.parent.in_stroke = True
644
+ self.drawAt(ev.pos(), ev)
645
+ else:
646
+ ev.accept()
647
+ self.end_stroke()
648
+ self.parent.in_stroke = False
649
+ elif not self.parent.in_stroke:
650
+ y, x = int(ev.pos().y()), int(ev.pos().x())
651
+ if y >= 0 and y < self.parent.Ly and x >= 0 and x < self.parent.Lx:
652
+ if ev.button() == QtCore.Qt.LeftButton and not ev.double():
653
+ idx = self.parent.cellpix[self.parent.currentZ][y, x]
654
+ if idx > 0:
655
+ if ev.modifiers() & QtCore.Qt.ControlModifier:
656
+ # delete mask selected
657
+ self.parent.remove_cell(idx)
658
+ elif ev.modifiers() & QtCore.Qt.AltModifier:
659
+ self.parent.merge_cells(idx)
660
+ elif self.parent.masksOn and not self.parent.deleting_multiple:
661
+ self.parent.unselect_cell()
662
+ self.parent.select_cell(idx)
663
+ elif self.parent.deleting_multiple:
664
+ if idx in self.parent.removing_cells_list:
665
+ self.parent.unselect_cell_multi(idx)
666
+ self.parent.removing_cells_list.remove(idx)
667
+ else:
668
+ self.parent.select_cell_multi(idx)
669
+ self.parent.removing_cells_list.append(idx)
670
+
671
+ elif self.parent.masksOn and not self.parent.deleting_multiple:
672
+ self.parent.unselect_cell()
673
+
674
+ def mouseDragEvent(self, ev):
675
+ ev.ignore()
676
+ return
677
+
678
+ def hoverEvent(self, ev):
679
+ if self.parent.in_stroke:
681
+ # continue stroke if not at start
682
+ self.drawAt(ev.pos())
683
+ if self.is_at_start(ev.pos()):
684
+ self.end_stroke()
685
+ else:
686
+ ev.acceptClicks(QtCore.Qt.RightButton)
687
+
688
+ def create_start(self, pos):
689
+ self.scatter = pg.ScatterPlotItem([pos.x()], [pos.y()], pxMode=False,
690
+ pen=pg.mkPen(color=(255, 0, 0),
691
+ width=self.parent.brush_size),
692
+ size=max(3 * 2,
693
+ self.parent.brush_size * 1.8 * 2),
694
+ brush=None)
695
+ self.parent.p0.addItem(self.scatter)
696
+
697
+ def is_at_start(self, pos):
698
+ thresh_out = max(6, self.parent.brush_size * 3)
699
+ thresh_in = max(3, self.parent.brush_size * 1.8)
700
+ # first check if you ever left the start
701
+ if len(self.parent.current_stroke) > 3:
702
+ stroke = np.array(self.parent.current_stroke)
703
+ dist = (((stroke[1:, 1:] -
704
+ stroke[:1, 1:][np.newaxis, :, :])**2).sum(axis=-1))**0.5
705
+ dist = dist.flatten()
706
+ has_left = (dist > thresh_out).nonzero()[0]
707
+ if len(has_left) > 0:
708
+ first_left = np.sort(has_left)[0]
709
+ has_returned = (dist[max(4, first_left + 1):] < thresh_in).sum()
710
+ if has_returned > 0:
711
+ return True
712
+ else:
713
+ return False
714
+ else:
715
+ return False
716
+
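The leave-and-return test in is_at_start can be exercised on synthetic data. A simplified sketch (toy stroke, using only the y/x columns and dropping the original's small-index guard):

    import numpy as np

    stroke = np.array([[0, 5, 5, 0], [0, 9, 9, 0], [0, 30, 30, 0],
                       [0, 40, 40, 0], [0, 6, 6, 0]])  # rows of (z, y, x, iscent)
    thresh_out, thresh_in = 6, 3

    # distance of every later point from the starting point
    dist = np.linalg.norm(stroke[1:, 1:3] - stroke[:1, 1:3], axis=-1)
    has_left = np.nonzero(dist > thresh_out)[0]
    closed = len(has_left) > 0 and (dist[has_left[0] + 1:] < thresh_in).any()
    print(closed)  # True: the stroke left the start region and came back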
717
+ def end_stroke(self):
718
+ self.parent.p0.removeItem(self.scatter)
719
+ if not self.parent.stroke_appended:
720
+ self.parent.strokes.append(self.parent.current_stroke)
721
+ self.parent.stroke_appended = True
722
+ self.parent.current_stroke = np.array(self.parent.current_stroke)
723
+ ioutline = self.parent.current_stroke[:, 3] == 1
724
+ self.parent.current_point_set.append(
725
+ list(self.parent.current_stroke[ioutline]))
726
+ self.parent.current_stroke = []
727
+ if self.parent.autosave:
728
+ self.parent.add_set()
729
+ if len(self.parent.current_point_set) and len(
730
+ self.parent.current_point_set[0]) > 0 and self.parent.autosave:
731
+ self.parent.add_set()
732
+ self.parent.in_stroke = False
733
+
734
+ def tabletEvent(self, ev):
735
+ pass
736
+
737
+ def drawAt(self, pos, ev=None):
738
+ mask = self.strokemask
739
+ stroke = self.parent.current_stroke
740
+ pos = [int(pos.y()), int(pos.x())]
741
+ dk = self.drawKernel
742
+ kc = self.drawKernelCenter
743
+ sx = [0, dk.shape[0]]
744
+ sy = [0, dk.shape[1]]
745
+ tx = [pos[0] - kc[0], pos[0] - kc[0] + dk.shape[0]]
746
+ ty = [pos[1] - kc[1], pos[1] - kc[1] + dk.shape[1]]
747
+ kcent = kc.copy()
748
+ if tx[0] <= 0:
749
+ sx[0] = 0
750
+ sx[1] = kc[0] + 1
751
+ tx = sx
752
+ kcent[0] = 0
753
+ if ty[0] <= 0:
754
+ sy[0] = 0
755
+ sy[1] = kc[1] + 1
756
+ ty = sy
757
+ kcent[1] = 0
758
+ if tx[1] >= self.parent.Ly - 1:
759
+ sx[0] = dk.shape[0] - kc[0] - 1
760
+ sx[1] = dk.shape[0]
761
+ tx[0] = self.parent.Ly - kc[0] - 1
762
+ tx[1] = self.parent.Ly
763
+ kcent[0] = tx[1] - tx[0] - 1
764
+ if ty[1] >= self.parent.Lx - 1:
765
+ sy[0] = dk.shape[1] - kc[1] - 1
766
+ sy[1] = dk.shape[1]
767
+ ty[0] = self.parent.Lx - kc[1] - 1
768
+ ty[1] = self.parent.Lx
769
+ kcent[1] = ty[1] - ty[0] - 1
770
+
771
+ ts = (slice(tx[0], tx[1]), slice(ty[0], ty[1]))
772
+ ss = (slice(sx[0], sx[1]), slice(sy[0], sy[1]))
773
+ self.image[ts] = mask[ss]
774
+
775
+ for ky, y in enumerate(np.arange(ty[0], ty[1], 1, int)):
776
+ for kx, x in enumerate(np.arange(tx[0], tx[1], 1, int)):
777
+ iscent = np.logical_and(kx == kcent[0], ky == kcent[1])
778
+ stroke.append([self.parent.currentZ, x, y, iscent])
779
+ self.updateImage()
780
+
781
+ def setDrawKernel(self, kernel_size=3):
782
+ bs = kernel_size
783
+ kernel = np.ones((bs, bs), np.uint8)
784
+ self.drawKernel = kernel
785
+ self.drawKernelCenter = [
786
+ int(np.floor(kernel.shape[0] / 2)),
787
+ int(np.floor(kernel.shape[1] / 2))
788
+ ]
789
+ onmask = 255 * kernel[:, :, np.newaxis]
790
+ offmask = np.zeros((bs, bs, 1))
791
+ opamask = 100 * kernel[:, :, np.newaxis]
792
+ self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1)
793
+ self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1)
models/seg_post_model/cellpose/gui/guitrainhelpwindowtext.html ADDED
@@ -0,0 +1,25 @@
1
+ <qt>
2
+ Check out this <a href="https://youtu.be/3Y1VKcxjNy4">video</a> to learn the process.
3
+ <ol>
4
+ <li>Drag and drop an image from a folder of images with a similar style (like similar cell types).</li>
5
+ <li>Run the built-in models on one of the images using the "model zoo" and find the one that works best for your
6
+ data. Make sure that if you have a nuclear channel you have selected it for CHAN2.
7
+ </li>
8
+ <li>Fix the labelling by drawing new ROIs (right-click) and deleting incorrect ones (CTRL+click). The GUI
9
+ autosaves any manual changes (but does not autosave after running the model, for that click CTRL+S). The
10
+ segmentation is saved in a "_seg.npy" file.
11
+ </li>
12
+ <li> Go to the "Models" menu in the File bar at the top and click "Train new model..." or use shortcut CTRL+T.
13
+ </li>
14
+ <li> Choose the pretrained model to start the training from (the model you used in #2), and type in the model
15
+ name that you want to use. The other parameters should work well in general for most data types. Then click
16
+ OK.
17
+ </li>
18
+ <li> The model will train (much faster if you have a GPU) and then auto-run on the next image in the folder.
19
+ Next you can repeat #3-#5 as many times as is necessary.
20
+ </li>
21
+ <li> The trained model is available to use in the future in the GUI in the "custom model" section and is saved
22
+ in your image folder.
23
+ </li>
24
+ </ol>
25
+ </qt>
models/seg_post_model/cellpose/gui/io.py ADDED
@@ -0,0 +1,634 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import os, gc
5
+ import numpy as np
6
+ import cv2
7
+ import fastremap
8
+
9
+ from ..io import imread, imread_2D, imread_3D, imsave, outlines_to_text, add_model, remove_model, save_rois
10
+ from ..models import normalize_default, MODEL_DIR, MODEL_LIST_PATH, get_user_models
11
+ from ..utils import masks_to_outlines, outlines_list
12
+
13
+ try:
14
+ import qtpy
15
+ from qtpy.QtWidgets import QFileDialog
16
+ GUI = True
17
+ except Exception:
18
+ GUI = False
19
+
20
+ try:
21
+ import matplotlib.pyplot as plt
22
+ MATPLOTLIB = True
23
+ except Exception:
24
+ MATPLOTLIB = False
25
+
26
+
27
+ def _init_model_list(parent):
28
+ MODEL_DIR.mkdir(parents=True, exist_ok=True)
29
+ parent.model_list_path = MODEL_LIST_PATH
30
+ parent.model_strings = get_user_models()
31
+
32
+
33
+ def _add_model(parent, filename=None, load_model=True):
34
+ if filename is None:
35
+ name = QFileDialog.getOpenFileName(parent, "Add model to GUI")
36
+ filename = name[0]
37
+ add_model(filename)
38
+ fname = os.path.split(filename)[-1]
39
+ parent.ModelChooseC.addItems([fname])
40
+ parent.model_strings.append(fname)
41
+
42
+ for ind, model_string in enumerate(parent.model_strings[:-1]):
43
+ if model_string == fname:
44
+ _remove_model(parent, ind=ind + 1, verbose=False)
45
+
46
+ parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
47
+ if load_model:
48
+ parent.model_choose(custom=True)
49
+
50
+
51
+ def _remove_model(parent, ind=None, verbose=True):
52
+ if ind is None:
53
+ ind = parent.ModelChooseC.currentIndex()
54
+ if ind > 0:
55
+ ind -= 1
56
+ parent.ModelChooseC.removeItem(ind + 1)
57
+ del parent.model_strings[ind]
58
+ # remove model from txt path
59
+ modelstr = parent.ModelChooseC.currentText()
60
+ remove_model(modelstr)
61
+ if len(parent.model_strings) > 0:
62
+ parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
63
+ else:
64
+ parent.ModelChooseC.setCurrentIndex(0)
65
+ else:
66
+ print("ERROR: no model selected to delete")
67
+
68
+
69
+ def _get_train_set(image_names):
70
+ """ get training data and labels for images in current folder image_names"""
71
+ train_data, train_labels, train_files = [], [], []
72
+ restore = None
73
+ normalize_params = normalize_default
74
+ for image_name_full in image_names:
75
+ image_name = os.path.splitext(image_name_full)[0]
76
+ label_name = None
77
+ if os.path.exists(image_name + "_seg.npy"):
78
+ dat = np.load(image_name + "_seg.npy", allow_pickle=True).item()
79
+ masks = dat["masks"].squeeze()
80
+ if masks.ndim == 2:
81
+ fastremap.renumber(masks, in_place=True)
82
+ label_name = image_name + "_seg.npy"
83
+ else:
84
+ print(f"GUI_INFO: _seg.npy found for {image_name} but masks.ndim!=2")
85
+ if "img_restore" in dat:
86
+ data = dat["img_restore"].squeeze()
87
+ restore = dat["restore"]
88
+ else:
89
+ data = imread(image_name_full)
90
+ normalize_params = dat[
91
+ "normalize_params"] if "normalize_params" in dat else normalize_default
92
+ if label_name is not None:
93
+ train_files.append(image_name_full)
94
+ train_data.append(data)
95
+ train_labels.append(masks)
96
+ if restore:
97
+ print(f"GUI_INFO: using {restore} images (dat['img_restore'])")
98
+ return train_data, train_labels, train_files, restore, normalize_params
99
+
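Outside the GUI, the _seg.npy files consumed above can be inspected directly (hypothetical filename; keys as used in _get_train_set, some of which may be absent depending on what was saved):

    import numpy as np

    dat = np.load("image_seg.npy", allow_pickle=True).item()
    masks = dat["masks"]                      # integer label image, 0 = background
    params = dat.get("normalize_params", {})  # missing in older files
    restored = dat.get("img_restore")         # present only if restoration was run
    print(masks.shape, int(masks.max()), "masks")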
100
+
101
+ def _load_image(parent, filename=None, load_seg=True, load_3D=False):
102
+ """ load image with filename; if None, open QFileDialog
103
+ if the image is grayscale, change the view default to grayscale
104
+ """
105
+
106
+ if parent.load_3D:
107
+ load_3D = True
108
+
109
+ if filename is None:
110
+ name = QFileDialog.getOpenFileName(parent, "Load image")
111
+ filename = name[0]
112
+ if filename == "":
113
+ return
114
+ manual_file = os.path.splitext(filename)[0] + "_seg.npy"
115
+ load_mask = False
116
+ if load_seg:
117
+ if os.path.isfile(manual_file) and not parent.autoloadMasks.isChecked():
118
+ if filename is not None:
119
+ image = (imread_2D(filename) if not load_3D else
120
+ imread_3D(filename))
121
+ else:
122
+ image = None
123
+ _load_seg(parent, manual_file, image=image, image_file=filename,
124
+ load_3D=load_3D)
125
+ return
126
+ elif parent.autoloadMasks.isChecked():
127
+ mask_file = os.path.splitext(filename)[0] + "_masks" + os.path.splitext(
128
+ filename)[-1]
129
+ mask_file = os.path.splitext(filename)[
130
+ 0] + "_masks.tif" if not os.path.isfile(mask_file) else mask_file
131
+ load_mask = True if os.path.isfile(mask_file) else False
132
+ try:
133
+ print(f"GUI_INFO: loading image: {filename}")
134
+ if not load_3D:
135
+ image = imread_2D(filename)
136
+ else:
137
+ image = imread_3D(filename)
138
+ parent.loaded = True
139
+ except Exception as e:
140
+ print("ERROR: images not compatible")
141
+ print(f"ERROR: {e}")
142
+
143
+ if parent.loaded:
144
+ parent.reset()
145
+ parent.filename = filename
146
+ filename = os.path.split(parent.filename)[-1]
147
+ _initialize_images(parent, image, load_3D=load_3D)
148
+ parent.loaded = True
149
+ parent.enable_buttons()
150
+ if load_mask:
151
+ _load_masks(parent, filename=mask_file)
152
+
153
+ # check if gray and adjust viewer:
154
+ if len(np.unique(image[..., 1:])) == 1:
155
+ parent.color = 4
156
+ parent.RGBDropDown.setCurrentIndex(4) # gray
157
+ parent.update_plot()
158
+
159
+
160
+ def _initialize_images(parent, image, load_3D=False):
161
+ """ format image for GUI
162
+
163
+ assumes image is Z x W x H x C
164
+
165
+ """
166
+ load_3D = parent.load_3D if load_3D is False else load_3D
167
+
168
+ parent.stack = image
169
+ print(f"GUI_INFO: image shape: {image.shape}")
170
+ if load_3D:
171
+ parent.NZ = len(parent.stack)
172
+ parent.scroll.setMaximum(parent.NZ - 1)
173
+ else:
174
+ parent.NZ = 1
175
+ parent.stack = parent.stack[np.newaxis, ...]
176
+
177
+ img_min = image.min()
178
+ img_max = image.max()
179
+ parent.stack = parent.stack.astype(np.float32)
180
+ parent.stack -= img_min
181
+ if img_max > img_min + 1e-3:
182
+ parent.stack /= (img_max - img_min)
183
+ parent.stack *= 255
184
+
185
+ if load_3D:
186
+ print("GUI_INFO: converted to float and normalized values to 0.0->255.0")
187
+
188
+ del image
189
+ gc.collect()
190
+
191
+ parent.imask = 0
192
+ parent.Ly, parent.Lx = parent.stack.shape[-3:-1]
193
+ parent.Ly0, parent.Lx0 = parent.stack.shape[-3:-1]
194
+ parent.layerz = 255 * np.ones((parent.Ly, parent.Lx, 4), "uint8")
195
+ if hasattr(parent, "stack_filtered"):
196
+ parent.Lyr, parent.Lxr = parent.stack_filtered.shape[-3:-1]
197
+ elif parent.restore and "upsample" in parent.restore:
198
+ parent.Lyr, parent.Lxr = int(parent.Ly * parent.ratio), int(parent.Lx *
199
+ parent.ratio)
200
+ else:
201
+ parent.Lyr, parent.Lxr = parent.Ly, parent.Lx
202
+ parent.clear_all()
203
+
204
+ if not hasattr(parent, "stack_filtered") and parent.restore:
205
+ print("GUI_INFO: no 'img_restore' found, applying current settings")
206
+ parent.compute_restore()
207
+
208
+ if parent.autobtn.isChecked():
209
+ if parent.restore is None or parent.restore != "filter":
210
+ print(
211
+ "GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)"
212
+ )
213
+ parent.compute_saturation()
214
+ # elif len(parent.saturation) != parent.NZ:
215
+ # parent.saturation = []
216
+ # for r in range(3):
217
+ # parent.saturation.append([])
218
+ # for n in range(parent.NZ):
219
+ # parent.saturation[-1].append([0, 255])
220
+ # parent.sliders[r].setValue([0, 255])
221
+ parent.compute_scale()
222
+ parent.track_changes = []
223
+
224
+ if load_3D:
225
+ parent.currentZ = int(np.floor(parent.NZ / 2))
226
+ parent.scroll.setValue(parent.currentZ)
227
+ parent.zpos.setText(str(parent.currentZ))
228
+ else:
229
+ parent.currentZ = 0
230
+
231
+
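The min-max rescaling above is self-contained; a standalone sketch of the same normalization follows (the `1e-3` guard avoids dividing by near-zero for constant images; the input array is synthetic):

```python
import numpy as np

stack = np.random.rand(1, 256, 256, 3).astype(np.float32) * 50  # hypothetical image stack
img_min, img_max = stack.min(), stack.max()
stack -= img_min
if img_max > img_min + 1e-3:          # skip division for near-constant images
    stack /= (img_max - img_min)
stack *= 255                          # values now span 0.0 -> 255.0
```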
232
+ def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False):
233
+ """ load *_seg.npy with filename; if None, open QFileDialog """
234
+ if filename is None:
235
+ name = QFileDialog.getOpenFileName(parent, "Load labelled data", filter="*.npy")
236
+ filename = name[0]
237
+ try:
238
+ dat = np.load(filename, allow_pickle=True).item()
239
+ # check that the file contains the expected keys before accepting it
240
+ dat["outlines"]
241
+ parent.loaded = True
242
+ except:
243
+ parent.loaded = False
244
+ print("ERROR: not NPY")
245
+ return
246
+
247
+ parent.reset()
248
+ if image is None:
249
+ found_image = False
250
+ if "filename" in dat:
251
+ parent.filename = dat["filename"]
252
+ if os.path.isfile(parent.filename):
253
+ parent.filename = dat["filename"]
254
+ found_image = True
255
+ else:
256
+ imgname = os.path.split(parent.filename)[1]
257
+ root = os.path.split(filename)[0]
258
+ parent.filename = root + "/" + imgname
259
+ if os.path.isfile(parent.filename):
260
+ found_image = True
261
+ if found_image:
262
+ try:
263
+ print(parent.filename)
264
+ image = (imread_2D(parent.filename) if not load_3D else
265
+ imread_3D(parent.filename))
266
+ except:
267
+ parent.loaded = False
268
+ found_image = False
269
+ print("ERROR: cannot find image file, loading from npy")
270
+ if not found_image:
271
+ parent.filename = filename[:-8]
272
+ print(parent.filename)
273
+ if "img" in dat:
274
+ image = dat["img"]
275
+ else:
276
+ print("ERROR: no image file found and no image in npy")
277
+ return
278
+ else:
279
+ parent.filename = image_file
280
+
281
+ parent.restore = None
282
+ parent.ratio = 1.
283
+
284
+ if "normalize_params" in dat:
285
+ parent.set_normalize_params(dat["normalize_params"])
286
+
287
+ _initialize_images(parent, image, load_3D=load_3D)
288
+ print(parent.stack.shape)
289
+
290
+ if "outlines" in dat:
291
+ if isinstance(dat["outlines"], list):
292
+ # old way of saving files
293
+ dat["outlines"] = dat["outlines"][::-1]
294
+ for k, outline in enumerate(dat["outlines"]):
295
+ if "colors" in dat:
296
+ color = dat["colors"][k]
297
+ else:
298
+ col_rand = np.random.randint(1000)
299
+ color = parent.colormap[col_rand, :3]
300
+ median = parent.add_mask(points=outline, color=color)
301
+ if median is not None:
302
+ parent.cellcolors = np.append(parent.cellcolors,
303
+ color[np.newaxis, :], axis=0)
304
+ parent.ncells += 1
305
+ else:
306
+ if dat["masks"].min() == -1:
307
+ dat["masks"] += 1
308
+ dat["outlines"] += 1
309
+ parent.ncells.set(dat["masks"].max())
310
+ if "colors" in dat and len(dat["colors"]) == dat["masks"].max():
311
+ colors = dat["colors"]
312
+ else:
313
+ colors = parent.colormap[:parent.ncells.get(), :3]
314
+
315
+ _masks_to_gui(parent, dat["masks"], outlines=dat["outlines"], colors=colors)
316
+
317
+ parent.draw_layer()
318
+
319
+ if "manual_changes" in dat:
320
+ parent.track_changes = dat["manual_changes"]
321
+ print("GUI_INFO: loaded in previous changes")
322
+ if "zdraw" in dat:
323
+ parent.zdraw = dat["zdraw"]
324
+ else:
325
+ parent.zdraw = [None for n in range(parent.ncells.get())]
326
+ parent.loaded = True
327
+ else:
328
+ parent.clear_all()
329
+
330
+ parent.ismanual = np.zeros(parent.ncells.get(), bool)
331
+ if "ismanual" in dat:
332
+ if len(dat["ismanual"]) == parent.ncells:
333
+ parent.ismanual = dat["ismanual"]
334
+
335
+ if "current_channel" in dat:
336
+ parent.color = (dat["current_channel"] + 2) % 5
337
+ parent.RGBDropDown.setCurrentIndex(parent.color)
338
+
339
+ if "flows" in dat:
340
+ parent.flows = dat["flows"]
341
+ try:
342
+ if parent.flows[0].shape[-3] != dat["masks"].shape[-2]:
343
+ Ly, Lx = dat["masks"].shape[-2:]
344
+ for i in range(len(parent.flows)):
345
+ parent.flows[i] = cv2.resize(
346
+ parent.flows[i].squeeze(), (Lx, Ly),
347
+ interpolation=cv2.INTER_NEAREST)[np.newaxis, ...]
348
+ if parent.NZ == 1:
349
+ parent.recompute_masks = True
350
+ else:
351
+ parent.recompute_masks = False
352
+
353
+ except:
354
+ try:
355
+ if len(parent.flows[0]) > 0:
356
+ parent.flows = parent.flows[0]
357
+ except:
358
+ parent.flows = [[], [], [], [], [[]]]
359
+ parent.recompute_masks = False
360
+
361
+ parent.enable_buttons()
362
+ parent.update_layer()
363
+ del dat
364
+ gc.collect()
365
+
366
+
367
+ def _load_masks(parent, filename=None):
368
+ """ load zeros-based masks (0=no cell, 1=cell 1, ...) """
369
+ if filename is None:
370
+ name = QFileDialog.getOpenFileName(parent, "Load masks (PNG or TIFF)")
371
+ filename = name[0]
372
+ print(f"GUI_INFO: loading masks: {filename}")
373
+ masks = imread(filename)
374
+ outlines = None
375
+ if masks.ndim > 3:
376
+ # Z x nchannels x Ly x Lx
377
+ if masks.shape[-1] > 5:
378
+ parent.flows = list(np.transpose(masks[:, :, :, 2:], (3, 0, 1, 2)))
379
+ outlines = masks[..., 1]
380
+ masks = masks[..., 0]
381
+ else:
382
+ parent.flows = list(np.transpose(masks[:, :, :, 1:], (3, 0, 1, 2)))
383
+ masks = masks[..., 0]
384
+ elif masks.ndim == 3:
385
+ if masks.shape[-1] < 5:
386
+ masks = masks[np.newaxis, :, :, 0]
387
+ elif masks.ndim < 3:
388
+ masks = masks[np.newaxis, :, :]
389
+ # masks should be Z x Ly x Lx
390
+ if masks.shape[0] != parent.NZ:
391
+ print("ERROR: masks are not same depth (number of planes) as image stack")
392
+ return
393
+
394
+ _masks_to_gui(parent, masks, outlines)
395
+ if parent.ncells > 0:
396
+ parent.draw_layer()
397
+ parent.toggle_mask_ops()
398
+ del masks
399
+ gc.collect()
400
+ parent.update_layer()
401
+ parent.update_plot()
402
+
403
+
404
+ def _masks_to_gui(parent, masks, outlines=None, colors=None):
405
+ """ masks loaded into GUI """
406
+ # get unique values
407
+ shape = masks.shape
408
+ if len(fastremap.unique(masks)) != masks.max() + 1:
409
+ print("GUI_INFO: renumbering masks")
410
+ fastremap.renumber(masks, in_place=True)
411
+ outlines = None
412
+ masks = masks.reshape(shape)
413
+ if masks.ndim == 2:
414
+ outlines = None
415
+ masks = masks.astype(np.uint16) if masks.max() < 2**16 - 1 else masks.astype(
416
+ np.uint32)
417
+ if parent.restore and "upsample" in parent.restore:
418
+ parent.cellpix_resize = masks.copy()
419
+ parent.cellpix = parent.cellpix_resize.copy()
420
+ parent.cellpix_orig = cv2.resize(
421
+ masks.squeeze(), (parent.Lx0, parent.Ly0),
422
+ interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :]
423
+ parent.resize = True
424
+ else:
425
+ parent.cellpix = masks
426
+ if parent.cellpix.ndim == 2:
427
+ parent.cellpix = parent.cellpix[np.newaxis, :, :]
428
+ if parent.restore and "upsample" in parent.restore:
429
+ if parent.cellpix_resize.ndim == 2:
430
+ parent.cellpix_resize = parent.cellpix_resize[np.newaxis, :, :]
431
+ if parent.cellpix_orig.ndim == 2:
432
+ parent.cellpix_orig = parent.cellpix_orig[np.newaxis, :, :]
433
+
434
+ print(f"GUI_INFO: {masks.max()} masks found")
435
+
436
+ # get outlines
437
+ if outlines is None: # parent.outlinesOn
438
+ parent.outpix = np.zeros_like(parent.cellpix)
439
+ if parent.restore and "upsample" in parent.restore:
440
+ parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
441
+ for z in range(parent.NZ):
442
+ outlines = masks_to_outlines(parent.cellpix[z])
443
+ parent.outpix[z] = outlines * parent.cellpix[z]
444
+ if parent.restore and "upsample" in parent.restore:
445
+ outlines = masks_to_outlines(parent.cellpix_orig[z])
446
+ parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
447
+ if z % 50 == 0 and parent.NZ > 1:
448
+ print("GUI_INFO: plane %d outlines processed" % z)
449
+ if parent.restore and "upsample" in parent.restore:
450
+ parent.outpix_resize = parent.outpix.copy()
451
+ else:
452
+ parent.outpix = outlines
453
+ if parent.restore and "upsample" in parent.restore:
454
+ parent.outpix_resize = parent.outpix.copy()
455
+ parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
456
+ for z in range(parent.NZ):
457
+ outlines = masks_to_outlines(parent.cellpix_orig[z])
458
+ parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
459
+ if z % 50 == 0 and parent.NZ > 1:
460
+ print("GUI_INFO: plane %d outlines processed" % z)
461
+
462
+ if parent.outpix.ndim == 2:
463
+ parent.outpix = parent.outpix[np.newaxis, :, :]
464
+ if parent.restore and "upsample" in parent.restore:
465
+ if parent.outpix_resize.ndim == 2:
466
+ parent.outpix_resize = parent.outpix_resize[np.newaxis, :, :]
467
+ if parent.outpix_orig.ndim == 2:
468
+ parent.outpix_orig = parent.outpix_orig[np.newaxis, :, :]
469
+
470
+ parent.ncells.set(parent.cellpix.max())
471
+ colors = parent.colormap[:parent.ncells.get(), :3] if colors is None else colors
472
+ print("GUI_INFO: creating cellcolors and drawing masks")
473
+ parent.cellcolors = np.concatenate((np.array([[255, 255, 255]]), colors),
474
+ axis=0).astype(np.uint8)
475
+ if parent.ncells > 0:
476
+ parent.draw_layer()
477
+ parent.toggle_mask_ops()
478
+ parent.ismanual = np.zeros(parent.ncells.get(), bool)
479
+ parent.zdraw = list(-1 * np.ones(parent.ncells.get(), np.int16))
480
+
481
+ if hasattr(parent, "stack_filtered"):
482
+ parent.ViewDropDown.setCurrentIndex(parent.ViewDropDown.count() - 1)
483
+ print("set denoised/filtered view")
484
+ else:
485
+ parent.ViewDropDown.setCurrentIndex(0)
486
+
487
+
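The per-plane outline computation above boils down to one `masks_to_outlines` call per z-plane; a minimal sketch on a synthetic 2D label image (assumes cellpose is installed and importable):

```python
import numpy as np
from cellpose.utils import masks_to_outlines

masks = np.zeros((128, 128), np.uint16)
masks[20:60, 20:60] = 1                # one hypothetical cell
outlines = masks_to_outlines(masks)    # boolean map of outline pixels
outpix = outlines * masks              # outline pixels labelled by cell ID, as in the z-loop
```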
488
+ def _save_png(parent):
489
+ """ save masks to png or tiff (if 3D) """
490
+ filename = parent.filename
491
+ base = os.path.splitext(filename)[0]
492
+ if parent.NZ == 1:
493
+ if parent.cellpix[0].max() > 65534:
494
+ print("GUI_INFO: saving 2D masks to tif (too many masks for PNG)")
495
+ imsave(base + "_cp_masks.tif", parent.cellpix[0])
496
+ else:
497
+ print("GUI_INFO: saving 2D masks to png")
498
+ imsave(base + "_cp_masks.png", parent.cellpix[0].astype(np.uint16))
499
+ else:
500
+ print("GUI_INFO: saving 3D masks to tiff")
501
+ imsave(base + "_cp_masks.tif", parent.cellpix)
502
+
503
+
504
+ def _save_flows(parent):
505
+ """ save flows and cellprob to tiff """
506
+ filename = parent.filename
507
+ base = os.path.splitext(filename)[0]
508
+ print("GUI_INFO: saving flows and cellprob to tiff")
509
+ if len(parent.flows) > 0:
510
+ imsave(base + "_cp_cellprob.tif", parent.flows[1])
511
+ for i in range(3):
512
+ imsave(base + f"_cp_flows_{i}.tif", parent.flows[0][..., i])
513
+ if len(parent.flows) > 2:
514
+ imsave(base + "_cp_flows.tif", parent.flows[2])
515
+ print("GUI_INFO: saved flows and cellprob")
516
+ else:
517
+ print("ERROR: no flows or cellprob found")
518
+
519
+
520
+ def _save_rois(parent):
521
+ """ save masks as rois in .zip file for ImageJ """
522
+ filename = parent.filename
523
+ if parent.NZ == 1:
524
+ print(
525
+ f"GUI_INFO: saving {parent.cellpix[0].max()} ImageJ ROIs to .zip archive.")
526
+ save_rois(parent.cellpix[0], parent.filename)
527
+ else:
528
+ print("ERROR: cannot save 3D outlines")
529
+
530
+
531
+ def _save_outlines(parent):
532
+ filename = parent.filename
533
+ base = os.path.splitext(filename)[0]
534
+ if parent.NZ == 1:
535
+ print(
536
+ "GUI_INFO: saving 2D outlines to text file, see docs for info to load into ImageJ"
537
+ )
538
+ outlines = outlines_list(parent.cellpix[0])
539
+ outlines_to_text(base, outlines)
540
+ else:
541
+ print("ERROR: cannot save 3D outlines")
542
+
543
+
544
+ def _save_sets_with_check(parent):
545
+ """ Save masks and update *_seg.npy file. Use this function when saving should be optional
546
+ based on the disableAutosave checkbox. Otherwise, use _save_sets """
547
+ if not parent.disableAutosave.isChecked():
548
+ _save_sets(parent)
549
+
550
+
551
+ def _save_sets(parent):
552
+ """ save masks to *_seg.npy. This function should be used when saving
553
+ is forced, e.g. when clicking the save button. Otherwise, use _save_sets_with_check
554
+ """
555
+ filename = parent.filename
556
+ base = os.path.splitext(filename)[0]
557
+ flow_threshold = parent.segmentation_settings.flow_threshold
558
+ cellprob_threshold = parent.segmentation_settings.cellprob_threshold
559
+
560
+ if parent.NZ > 1:
561
+ dat = {
562
+ "outlines":
563
+ parent.outpix,
564
+ "colors":
565
+ parent.cellcolors[1:],
566
+ "masks":
567
+ parent.cellpix,
568
+ "current_channel": (parent.color - 2) % 5,
569
+ "filename":
570
+ parent.filename,
571
+ "flows":
572
+ parent.flows,
573
+ "zdraw":
574
+ parent.zdraw,
575
+ "model_path":
576
+ parent.current_model_path
577
+ if hasattr(parent, "current_model_path") else 0,
578
+ "flow_threshold":
579
+ flow_threshold,
580
+ "cellprob_threshold":
581
+ cellprob_threshold,
582
+ "normalize_params":
583
+ parent.get_normalize_params(),
584
+ "restore":
585
+ parent.restore,
586
+ "ratio":
587
+ parent.ratio,
588
+ "diameter":
589
+ parent.segmentation_settings.diameter
590
+ }
591
+ if parent.restore is not None:
592
+ dat["img_restore"] = parent.stack_filtered
593
+ else:
594
+ dat = {
595
+ "outlines":
596
+ parent.outpix.squeeze() if parent.restore is None or
597
+ not "upsample" in parent.restore else parent.outpix_resize.squeeze(),
598
+ "colors":
599
+ parent.cellcolors[1:],
600
+ "masks":
601
+ parent.cellpix.squeeze() if parent.restore is None or
602
+ not "upsample" in parent.restore else parent.cellpix_resize.squeeze(),
603
+ "filename":
604
+ parent.filename,
605
+ "flows":
606
+ parent.flows,
607
+ "ismanual":
608
+ parent.ismanual,
609
+ "manual_changes":
610
+ parent.track_changes,
611
+ "model_path":
612
+ parent.current_model_path
613
+ if hasattr(parent, "current_model_path") else 0,
614
+ "flow_threshold":
615
+ flow_threshold,
616
+ "cellprob_threshold":
617
+ cellprob_threshold,
618
+ "normalize_params":
619
+ parent.get_normalize_params(),
620
+ "restore":
621
+ parent.restore,
622
+ "ratio":
623
+ parent.ratio,
624
+ "diameter":
625
+ parent.segmentation_settings.diameter
626
+ }
627
+ if parent.restore is not None:
628
+ dat["img_restore"] = parent.stack_filtered
629
+ try:
630
+ np.save(base + "_seg.npy", dat)
631
+ print("GUI_INFO: %d ROIs saved to %s" % (parent.ncells.get(), base + "_seg.npy"))
632
+ except Exception as e:
633
+ print(f"ERROR: {e}")
634
+ del dat
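To read back a file written by `_save_sets`, loading the pickled dict is enough; a sketch (the file path is hypothetical):

```python
import numpy as np

dat = np.load("image_seg.npy", allow_pickle=True).item()
masks = dat["masks"]            # 0 = background, 1..N = cell labels
outlines = dat["outlines"]
print(masks.max(), "ROIs; restore mode:", dat.get("restore"))
```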
models/seg_post_model/cellpose/gui/make_train.py ADDED
@@ -0,0 +1,107 @@
1
+ import os, argparse
2
+ import numpy as np
3
+ from cellpose import io, transforms
4
+
5
+
6
+ def main():
7
+ parser = argparse.ArgumentParser(description='Make slices of XYZ image data for training. Assumes image is ZXYC unless specified otherwise using --channel_axis and --z_axis')
8
+
9
+ input_img_args = parser.add_argument_group("input image arguments")
10
+ input_img_args.add_argument('--dir', default=[], type=str,
11
+ help='folder containing data to run or train on.')
12
+ input_img_args.add_argument(
13
+ '--image_path', default=[], type=str, help=
14
+ 'if given and --dir not given, run on single image instead of folder (cannot train with this option)'
15
+ )
16
+ input_img_args.add_argument(
17
+ '--look_one_level_down', action='store_true',
18
+ help='run processing on all subdirectories of current folder')
19
+ input_img_args.add_argument('--img_filter', default=[], type=str,
20
+ help='end string for images to run on')
21
+ input_img_args.add_argument(
22
+ '--channel_axis', default=-1, type=int,
23
+ help='axis of image which corresponds to image channels')
24
+ input_img_args.add_argument('--z_axis', default=0, type=int,
25
+ help='axis of image which corresponds to Z dimension')
26
+ input_img_args.add_argument(
27
+ '--chan', default=0, type=int, help=
28
+ 'Deprecated')
29
+ input_img_args.add_argument(
30
+ '--chan2', default=0, type=int, help=
31
+ 'Deprecated'
32
+ )
33
+ input_img_args.add_argument('--invert', action='store_true',
34
+ help='invert grayscale channel')
35
+ input_img_args.add_argument(
36
+ '--all_channels', action='store_true', help=
37
+ 'deprecated')
38
+ input_img_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
39
+ help="anisotropy of volume in 3D")
40
+
41
+
42
+ # algorithm settings
43
+ algorithm_args = parser.add_argument_group("algorithm arguments")
44
+ algorithm_args.add_argument('--sharpen_radius', required=False, default=0.0,
45
+ type=float, help='high-pass filtering radius. Default: %(default)s')
46
+ algorithm_args.add_argument('--tile_norm', required=False, default=0, type=int,
47
+ help='tile normalization block size. Default: %(default)s')
48
+ algorithm_args.add_argument('--nimg_per_tif', required=False, default=10, type=int,
49
+ help='number of crops in XY to save per tiff. Default: %(default)s')
50
+ algorithm_args.add_argument('--crop_size', required=False, default=512, type=int,
51
+ help='size of random crop to save. Default: %(default)s')
52
+
53
+ args = parser.parse_args()
54
+
55
+ # find images
56
+ if len(args.img_filter) > 0:
57
+ imf = args.img_filter
58
+ else:
59
+ imf = None
60
+
61
+ if len(args.dir) > 0:
62
+ image_names = io.get_image_files(args.dir, "_masks", imf=imf,
63
+ look_one_level_down=args.look_one_level_down)
64
+ dirname = args.dir
65
+ else:
66
+ if os.path.exists(args.image_path):
67
+ image_names = [args.image_path]
68
+ dirname = os.path.split(args.image_path)[0]
69
+ else:
70
+ raise ValueError(f"ERROR: no file found at {args.image_path}")
71
+
72
+ np.random.seed(0)
73
+ nimg_per_tif = args.nimg_per_tif
74
+ crop_size = args.crop_size
75
+ os.makedirs(os.path.join(dirname, 'train/'), exist_ok=True)
76
+ pm = [(0, 1, 2, 3), (2, 0, 1, 3), (1, 0, 2, 3)]
77
+ npm = ["YX", "ZY", "ZX"]
78
+ for name in image_names:
79
+ name0 = os.path.splitext(os.path.split(name)[-1])[0]
80
+ img0 = io.imread_3D(name)
81
+ try:
82
+ img0 = transforms.convert_image(img0, channel_axis=args.channel_axis,
83
+ z_axis=args.z_axis, do_3D=True)
84
+ except ValueError:
85
+ print('Error converting image. Did you provide the correct --channel_axis and --z_axis?')
+ continue # skip this image rather than slicing an unconverted volume
86
+
87
+ for p in range(3):
88
+ img = img0.transpose(pm[p]).copy()
89
+ print(npm[p], img[0].shape)
90
+ Ly, Lx = img.shape[1:3]
91
+ imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]]
92
+ if args.anisotropy > 1.0 and p > 0:
93
+ imgs = transforms.resize_image(imgs, Ly=int(args.anisotropy * Ly), Lx=Lx)
94
+ for k, img in enumerate(imgs):
95
+ if args.tile_norm:
96
+ img = transforms.normalize99_tile(img, blocksize=args.tile_norm)
97
+ if args.sharpen_radius:
98
+ img = transforms.smooth_sharpen_img(img,
99
+ sharpen_radius=args.sharpen_radius)
100
+ ly = 0 if Ly - crop_size <= 0 else np.random.randint(0, Ly - crop_size)
101
+ lx = 0 if Lx - crop_size <= 0 else np.random.randint(0, Lx - crop_size)
102
+ io.imsave(os.path.join(dirname, f'train/{name0}_{npm[p]}_{k}.tif'),
103
+ img[ly:ly + args.crop_size, lx:lx + args.crop_size].squeeze())
104
+
105
+
106
+ if __name__ == '__main__':
107
+ main()
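A standalone sketch of the random-crop step used in the loop above, on a hypothetical single-channel Z x Y x X x C volume (same guard against crops larger than the plane):

```python
import numpy as np

img = np.random.rand(40, 800, 600, 1).astype("float32")  # hypothetical ZYXC volume
crop_size, nimg_per_tif = 512, 10
planes = img[np.random.permutation(img.shape[0])[:nimg_per_tif]]
Ly, Lx = planes.shape[1:3]
ly = 0 if Ly - crop_size <= 0 else np.random.randint(0, Ly - crop_size)
lx = 0 if Lx - crop_size <= 0 else np.random.randint(0, Lx - crop_size)
crop = planes[0, ly:ly + crop_size, lx:lx + crop_size]  # one 512 x 512 training crop
```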
models/seg_post_model/cellpose/gui/menus.py ADDED
@@ -0,0 +1,145 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute. Authored by Carsen Stringer, Michael Rariden, and Marius Pachitariu.
3
+ """
4
+ from qtpy.QtWidgets import QAction
5
+ from . import io
6
+
7
+
8
+ def mainmenu(parent):
9
+ main_menu = parent.menuBar()
10
+ file_menu = main_menu.addMenu("&File")
11
+ # load processed data
12
+ loadImg = QAction("&Load image (*.tif, *.png, *.jpg)", parent)
13
+ loadImg.setShortcut("Ctrl+L")
14
+ loadImg.triggered.connect(lambda: io._load_image(parent))
15
+ file_menu.addAction(loadImg)
16
+
17
+ parent.autoloadMasks = QAction("Autoload masks from _masks.tif file", parent,
18
+ checkable=True)
19
+ parent.autoloadMasks.setChecked(False)
20
+ file_menu.addAction(parent.autoloadMasks)
21
+
22
+ parent.disableAutosave = QAction("Disable autosave _seg.npy file", parent,
23
+ checkable=True)
24
+ parent.disableAutosave.setChecked(False)
25
+ file_menu.addAction(parent.disableAutosave)
26
+
27
+ parent.loadMasks = QAction("Load &masks (*.tif, *.png, *.jpg)", parent)
28
+ parent.loadMasks.setShortcut("Ctrl+M")
29
+ parent.loadMasks.triggered.connect(lambda: io._load_masks(parent))
30
+ file_menu.addAction(parent.loadMasks)
31
+ parent.loadMasks.setEnabled(False)
32
+
33
+ loadManual = QAction("Load &processed/labelled image (*_seg.npy)", parent)
34
+ loadManual.setShortcut("Ctrl+P")
35
+ loadManual.triggered.connect(lambda: io._load_seg(parent))
36
+ file_menu.addAction(loadManual)
37
+
38
+ parent.saveSet = QAction("&Save masks and image (as *_seg.npy)", parent)
39
+ parent.saveSet.setShortcut("Ctrl+S")
40
+ parent.saveSet.triggered.connect(lambda: io._save_sets(parent))
41
+ file_menu.addAction(parent.saveSet)
42
+ parent.saveSet.setEnabled(False)
43
+
44
+ parent.savePNG = QAction("Save masks as P&NG/tif", parent)
45
+ parent.savePNG.setShortcut("Ctrl+N")
46
+ parent.savePNG.triggered.connect(lambda: io._save_png(parent))
47
+ file_menu.addAction(parent.savePNG)
48
+ parent.savePNG.setEnabled(False)
49
+
50
+ parent.saveOutlines = QAction("Save &Outlines as text for imageJ", parent)
51
+ parent.saveOutlines.setShortcut("Ctrl+O")
52
+ parent.saveOutlines.triggered.connect(lambda: io._save_outlines(parent))
53
+ file_menu.addAction(parent.saveOutlines)
54
+ parent.saveOutlines.setEnabled(False)
55
+
56
+ parent.saveROIs = QAction("Save outlines as .zip archive of &ROI files for ImageJ",
57
+ parent)
58
+ parent.saveROIs.setShortcut("Ctrl+R")
59
+ parent.saveROIs.triggered.connect(lambda: io._save_rois(parent))
60
+ file_menu.addAction(parent.saveROIs)
61
+ parent.saveROIs.setEnabled(False)
62
+
63
+ parent.saveFlows = QAction("Save &Flows and cellprob as tif", parent)
64
+ parent.saveFlows.setShortcut("Ctrl+F")
65
+ parent.saveFlows.triggered.connect(lambda: io._save_flows(parent))
66
+ file_menu.addAction(parent.saveFlows)
67
+ parent.saveFlows.setEnabled(False)
68
+
69
+
70
+ def editmenu(parent):
71
+ main_menu = parent.menuBar()
72
+ edit_menu = main_menu.addMenu("&Edit")
73
+ parent.undo = QAction("Undo previous mask/trace", parent)
74
+ parent.undo.setShortcut("Ctrl+Z")
75
+ parent.undo.triggered.connect(parent.undo_action)
76
+ parent.undo.setEnabled(False)
77
+ edit_menu.addAction(parent.undo)
78
+
79
+ parent.redo = QAction("Undo remove mask", parent)
80
+ parent.redo.setShortcut("Ctrl+Y")
81
+ parent.redo.triggered.connect(parent.undo_remove_action)
82
+ parent.redo.setEnabled(False)
83
+ edit_menu.addAction(parent.redo)
84
+
85
+ parent.ClearButton = QAction("Clear all masks", parent)
86
+ parent.ClearButton.setShortcut("Ctrl+0")
87
+ parent.ClearButton.triggered.connect(parent.clear_all)
88
+ parent.ClearButton.setEnabled(False)
89
+ edit_menu.addAction(parent.ClearButton)
90
+
91
+ parent.remcell = QAction("Remove selected cell (Ctrl+CLICK)", parent)
92
+ parent.remcell.setShortcut("Ctrl+Click")
93
+ parent.remcell.triggered.connect(parent.remove_action)
94
+ parent.remcell.setEnabled(False)
95
+ edit_menu.addAction(parent.remcell)
96
+
97
+ parent.mergecell = QAction("FYI: Merge cells by Alt+Click", parent)
98
+ parent.mergecell.setEnabled(False)
99
+ edit_menu.addAction(parent.mergecell)
100
+
101
+
102
+ def modelmenu(parent):
103
+ main_menu = parent.menuBar()
104
+ io._init_model_list(parent)
105
+ model_menu = main_menu.addMenu("&Models")
106
+ parent.addmodel = QAction("Add custom torch model to GUI", parent)
107
+ #parent.addmodel.setShortcut("Ctrl+A")
108
+ parent.addmodel.triggered.connect(parent.add_model)
109
+ parent.addmodel.setEnabled(True)
110
+ model_menu.addAction(parent.addmodel)
111
+
112
+ parent.removemodel = QAction("Remove selected custom model from GUI", parent)
113
+ #parent.removemodel.setShortcut("Ctrl+R")
114
+ parent.removemodel.triggered.connect(parent.remove_model)
115
+ parent.removemodel.setEnabled(True)
116
+ model_menu.addAction(parent.removemodel)
117
+
118
+ parent.newmodel = QAction("&Train new model with image+masks in folder", parent)
119
+ parent.newmodel.setShortcut("Ctrl+T")
120
+ parent.newmodel.triggered.connect(parent.new_model)
121
+ parent.newmodel.setEnabled(False)
122
+ model_menu.addAction(parent.newmodel)
123
+
124
+ openTrainHelp = QAction("Training instructions", parent)
125
+ openTrainHelp.triggered.connect(parent.train_help_window)
126
+ model_menu.addAction(openTrainHelp)
127
+
128
+
129
+ def helpmenu(parent):
130
+ main_menu = parent.menuBar()
131
+ help_menu = main_menu.addMenu("&Help")
132
+
133
+ openHelp = QAction("&Help with GUI", parent)
134
+ openHelp.setShortcut("Ctrl+H")
135
+ openHelp.triggered.connect(parent.help_window)
136
+ help_menu.addAction(openHelp)
137
+
138
+ openGUI = QAction("&GUI layout", parent)
139
+ openGUI.setShortcut("Ctrl+G")
140
+ openGUI.triggered.connect(parent.gui_window)
141
+ help_menu.addAction(openGUI)
142
+
143
+ openTrainHelp = QAction("Training instructions", parent)
144
+ openTrainHelp.triggered.connect(parent.train_help_window)
145
+ help_menu.addAction(openTrainHelp)
models/seg_post_model/cellpose/io.py ADDED
@@ -0,0 +1,816 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute. Authored by Carsen Stringer, Michael Rariden, and Marius Pachitariu.
3
+ """
4
+ import os, warnings, glob, shutil
5
+ from natsort import natsorted
6
+ import numpy as np
7
+ import cv2
8
+ import tifffile
9
+ import logging, pathlib, sys
10
+ from tqdm import tqdm
11
+ from pathlib import Path
12
+ import re
13
+ from .version import version_str
14
+ from roifile import ImagejRoi, roiwrite
15
+
16
+ try:
17
+ from qtpy import QtGui, QtCore, Qt, QtWidgets
18
+ from qtpy.QtWidgets import QMessageBox
19
+ GUI = True
20
+ except:
21
+ GUI = False
22
+
23
+ try:
24
+ import matplotlib.pyplot as plt
25
+ MATPLOTLIB = True
26
+ except:
27
+ MATPLOTLIB = False
28
+
29
+ try:
30
+ import nd2
31
+ ND2 = True
32
+ except:
33
+ ND2 = False
34
+
35
+ try:
36
+ import nrrd
37
+ NRRD = True
38
+ except:
39
+ NRRD = False
40
+
41
+ try:
42
+ from google.cloud import storage
43
+ SERVER_UPLOAD = True
44
+ except:
45
+ SERVER_UPLOAD = False
46
+
47
+ io_logger = logging.getLogger(__name__)
48
+
49
+ def logger_setup(cp_path=".cellpose", logfile_name="run.log", stdout_file_replacement=None):
50
+ cp_dir = pathlib.Path.home().joinpath(cp_path)
51
+ cp_dir.mkdir(exist_ok=True)
52
+ log_file = cp_dir.joinpath(logfile_name)
53
+ try:
54
+ log_file.unlink()
55
+ except:
56
+ print('creating new log file')
57
+ handlers = [logging.FileHandler(log_file),]
58
+ if stdout_file_replacement is not None:
59
+ handlers.append(logging.FileHandler(stdout_file_replacement))
60
+ else:
61
+ handlers.append(logging.StreamHandler(sys.stdout))
62
+ logging.basicConfig(
63
+ level=logging.INFO,
64
+ format="%(asctime)s [%(levelname)s] %(message)s",
65
+ handlers=handlers,
66
+ force=True
67
+ )
68
+ logger = logging.getLogger(__name__)
69
+ logger.info(f"WRITING LOG OUTPUT TO {log_file}")
70
+ logger.info(version_str)
71
+
72
+ return logger, log_file
73
+
74
+
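Typical use of `logger_setup` is a one-liner at program start; a sketch, assuming the package is importable as `cellpose`:

```python
from cellpose import io  # assumed import path

logger, log_file = io.logger_setup()   # creates ~/.cellpose/run.log and logs to stdout
logger.info("starting segmentation run")
```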
75
+ from . import utils, plot, transforms
76
+
77
+ # helper function to check for a path; if it doesn't exist, make it
78
+ def check_dir(path):
79
+ if not os.path.isdir(path):
80
+ os.mkdir(path)
81
+
82
+
83
+ def outlines_to_text(base, outlines):
84
+ with open(base + "_cp_outlines.txt", "w") as f:
85
+ for o in outlines:
86
+ xy = list(o.flatten())
87
+ xy_str = ",".join(map(str, xy))
88
+ f.write(xy_str)
89
+ f.write("\n")
90
+
91
+
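The `_cp_outlines.txt` format written above is one ROI per line, with the flattened x,y pairs joined by commas; a sketch of reading it back (the file name is hypothetical):

```python
import numpy as np

outlines = []
with open("image_cp_outlines.txt") as f:
    for line in f:
        xy = np.array(line.strip().split(","), dtype=int)
        outlines.append(xy.reshape(-1, 2))   # back to an N x 2 array of (x, y) points
```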
92
+ def load_dax(filename):
93
+ ### modified from ZhuangLab github:
94
+ ### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156
95
+
96
+ inf_filename = os.path.splitext(filename)[0] + ".inf"
97
+ if not os.path.exists(inf_filename):
98
+ io_logger.critical(
99
+ f"ERROR: no inf file found for dax file {filename}, cannot load dax without it"
100
+ )
101
+ return None
102
+
103
+ ### get metadata
104
+ image_height, image_width = None, None
105
+ # extract the movie information from the associated inf file
106
+ size_re = re.compile(r"frame dimensions = ([\d]+) x ([\d]+)")
107
+ length_re = re.compile(r"number of frames = ([\d]+)")
108
+ endian_re = re.compile(r" (big|little) endian")
109
+
110
+ with open(inf_filename, "r") as inf_file:
111
+ lines = inf_file.read().split("\n")
112
+ for line in lines:
113
+ m = size_re.match(line)
114
+ if m:
115
+ image_height = int(m.group(2))
116
+ image_width = int(m.group(1))
117
+ m = length_re.match(line)
118
+ if m:
119
+ number_frames = int(m.group(1))
120
+ m = endian_re.search(line)
121
+ if m:
122
+ if m.group(1) == "big":
123
+ bigendian = 1
124
+ else:
125
+ bigendian = 0
126
+ # set defaults, warn the user that they couldn't be determined from the inf file.
127
+ if not image_height:
128
+ io_logger.warning("could not determine dax image size, assuming 256x256")
129
+ image_height = 256
130
+ image_width = 256
131
+
132
+ ### load image
133
+ img = np.memmap(filename, dtype="uint16",
134
+ shape=(number_frames, image_height, image_width))
135
+ if bigendian:
136
+ img = img.byteswap()
137
+ img = np.array(img)
138
+
139
+ return img
140
+
141
+
142
+ def imread(filename):
143
+ """
144
+ Read in an image file: tif/tiff/flex, dax, nd2, nrrd, or any image file type supported by cv2.
145
+
146
+ Args:
147
+ filename (str): The path to the image file.
148
+
149
+ Returns:
150
+ numpy.ndarray: The image data as a NumPy array.
151
+
152
+ Raises:
153
+ None
154
+
155
+ Logs a critical error and returns None if the image file format is not supported.
156
+
157
+ Examples:
158
+ >>> img = imread("image.tif")
159
+ """
160
+ # ensure that extension check is not case sensitive
161
+ ext = os.path.splitext(filename)[-1].lower()
162
+ if ext == ".tif" or ext == ".tiff" or ext == ".flex":
163
+ with tifffile.TiffFile(filename) as tif:
164
+ ltif = len(tif.pages)
165
+ try:
166
+ full_shape = tif.shaped_metadata[0]["shape"]
167
+ except:
168
+ try:
169
+ page = tif.series[0][0]
170
+ full_shape = tif.series[0].shape
171
+ except:
172
+ ltif = 0
173
+ if ltif < 10:
174
+ img = tif.asarray()
175
+ else:
176
+ page = tif.series[0][0]
177
+ shape, dtype = page.shape, page.dtype
178
+ ltif = int(np.prod(full_shape) / np.prod(shape))
179
+ io_logger.info(f"reading tiff with {ltif} planes")
180
+ img = np.zeros((ltif, *shape), dtype=dtype)
181
+ for i, page in enumerate(tqdm(tif.series[0])):
182
+ img[i] = page.asarray()
183
+ img = img.reshape(full_shape)
184
+ return img
185
+ elif ext == ".dax":
186
+ img = load_dax(filename)
187
+ return img
188
+ elif ext == ".nd2":
189
+ if not ND2:
190
+ io_logger.critical("ERROR: need to 'pip install nd2' to load in .nd2 file")
191
+ return None
+ else:
+ # read the file with the nd2 package (assumes its top-level nd2.imread API)
+ img = nd2.imread(filename)
+ return img
192
+ elif ext == ".nrrd":
193
+ if not NRRD:
194
+ io_logger.critical(
195
+ "ERROR: need to 'pip install pynrrd' to load in .nrrd file")
196
+ return None
197
+ else:
198
+ img, metadata = nrrd.read(filename)
199
+ if img.ndim == 3:
200
+ img = img.transpose(2, 0, 1)
201
+ return img
202
+ elif ext != ".npy":
203
+ try:
204
+ img = cv2.imread(filename, -1) #cv2.LOAD_IMAGE_ANYDEPTH)
205
+ if img.ndim > 2:
206
+ img = img[..., [2, 1, 0]]
207
+ return img
208
+ except Exception as e:
209
+ io_logger.critical("ERROR: could not read file, %s" % e)
210
+ return None
211
+ else:
212
+ try:
213
+ dat = np.load(filename, allow_pickle=True).item()
214
+ masks = dat["masks"]
215
+ return masks
216
+ except Exception as e:
217
+ io_logger.critical("ERROR: could not read masks from file, %s" % e)
218
+ return None
219
+
220
+
221
+ def imread_2D(img_file):
222
+ """
223
+ Read in a 2D image file and convert it to a 3-channel image. Attempts to do this for multi-channel and grayscale images.
224
+ If the image has more than 3 channels, only the first 3 channels are kept.
225
+
226
+ Args:
227
+ img_file (str): The path to the image file.
228
+
229
+ Returns:
230
+ img_out (numpy.ndarray): The 3-channel image data as a NumPy array.
231
+ """
232
+ img = imread(img_file)
233
+ return transforms.convert_image(img, do_3D=False)
234
+
235
+
236
+ def imread_3D(img_file):
237
+ """
238
+ Read in a 3D image file and convert it to have a channel axis last automatically. Attempts to do this for multi-channel and grayscale images.
239
+
240
+ If multichannel image, the channel axis is assumed to be the smallest dimension, and the z axis is the next smallest dimension.
241
+ Use `cellpose.io.imread()` to load the full image without selecting the z and channel axes.
242
+
243
+ Args:
244
+ img_file (str): The path to the image file.
245
+
246
+ Returns:
247
+ img_out (numpy.ndarray): The image data as a NumPy array.
248
+ """
249
+ img = imread(img_file)
250
+
251
+ dimension_lengths = list(img.shape)
252
+
253
+ # grayscale images:
254
+ if img.ndim == 3:
255
+ channel_axis = None
256
+ # guess at z axis:
257
+ z_axis = np.argmin(dimension_lengths)
258
+
259
+ elif img.ndim == 4:
260
+ # guess at channel axis:
261
+ channel_axis = np.argmin(dimension_lengths)
262
+
263
+ # guess at z axis:
264
+ # set channel axis to max so argmin works:
265
+ dimension_lengths[channel_axis] = max(dimension_lengths)
266
+ z_axis = np.argmin(dimension_lengths)
267
+
268
+ else:
269
+ raise ValueError(f'image shape error: a 3D image must be 3 or 4 dimensional, got {img.ndim} dimensions')
270
+
271
+ try:
272
+ return transforms.convert_image(img, channel_axis=channel_axis, z_axis=z_axis, do_3D=True)
273
+ except Exception as e:
274
+ io_logger.critical("ERROR: could not read file, %s" % e)
275
+ io_logger.critical("ERROR: Guessed z_axis: %s, channel_axis: %s" % (z_axis, channel_axis))
276
+ return None
277
+
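The axis guessing above can be traced by hand; for a hypothetical 4D stack of shape (30, 512, 512, 2), the smallest dimension (2) is taken as channels and the next smallest (30) as z:

```python
import numpy as np

shape = [30, 512, 512, 2]                 # hypothetical Z x Y x X x C stack
channel_axis = int(np.argmin(shape))      # -> 3 (size 2)
shape[channel_axis] = max(shape)          # mask it out before the second argmin
z_axis = int(np.argmin(shape))            # -> 0 (size 30)
print(channel_axis, z_axis)               # 3 0
```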
278
+ def remove_model(filename, delete=False):
279
+ """ remove model from .cellpose custom model list """
280
+ filename = os.path.split(filename)[-1]
281
+ from . import models
282
+ model_strings = models.get_user_models()
283
+ if len(model_strings) > 0:
284
+ with open(models.MODEL_LIST_PATH, "w") as textfile:
285
+ for fname in model_strings:
286
+ if fname != filename: # skip the model being removed
+ textfile.write(fname + "\n")
287
+ else:
288
+ # write empty file
289
+ textfile = open(models.MODEL_LIST_PATH, "w")
290
+ textfile.close()
291
+ print(f"{filename} removed from custom model list")
292
+ if delete:
293
+ os.remove(os.fspath(models.MODEL_DIR.joinpath(filename)))
294
+ print("model deleted")
295
+
296
+
297
+ def add_model(filename):
298
+ """ add model to .cellpose models folder to use with GUI or CLI """
299
+ from . import models
300
+ fname = os.path.split(filename)[-1]
301
+ try:
302
+ shutil.copyfile(filename, os.fspath(models.MODEL_DIR.joinpath(fname)))
303
+ except shutil.SameFileError:
304
+ pass
305
+ print(f"{filename} copied to models folder {os.fspath(models.MODEL_DIR)}")
306
+ if fname not in models.get_user_models():
307
+ with open(models.MODEL_LIST_PATH, "a") as textfile:
308
+ textfile.write(fname + "\n")
309
+
310
+
311
+ def imsave(filename, arr):
312
+ """
313
+ Saves an image array to a file.
314
+
315
+ Args:
316
+ filename (str): The name of the file to save the image to.
317
+ arr (numpy.ndarray): The image array to be saved.
318
+
319
+ Returns:
320
+ None
321
+ """
322
+ ext = os.path.splitext(filename)[-1].lower()
323
+ if ext == ".tif" or ext == ".tiff":
324
+ tifffile.imwrite(filename, data=arr, compression="zlib")
325
+ else:
326
+ if len(arr.shape) > 2:
327
+ arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
328
+ cv2.imwrite(filename, arr)
329
+
330
+
331
+ def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):
332
+ """
333
+ Finds all images in a folder and its subfolders (if specified) with the given file extensions.
334
+
335
+ Args:
336
+ folder (str): The path to the folder to search for images.
337
+ mask_filter (str): The filter for mask files.
338
+ imf (str, optional): The additional filter for image files. Defaults to None.
339
+ look_one_level_down (bool, optional): Whether to search for images in subfolders. Defaults to False.
340
+
341
+ Returns:
342
+ list: A list of image file paths.
343
+
344
+ Raises:
345
+ ValueError: If no files are found in the specified folder.
346
+ ValueError: If no images are found in the specified folder with the supported file extensions.
347
+ ValueError: If no images are found in the specified folder without the mask or flow file endings.
348
+ """
349
+ mask_filters = ["_cp_output", "_flows", "_flows_0", "_flows_1",
350
+ "_flows_2", "_cellprob", "_masks", mask_filter]
351
+ image_names = []
352
+ if imf is None:
353
+ imf = ""
354
+
355
+ folders = []
356
+ if look_one_level_down:
357
+ folders = natsorted(glob.glob(os.path.join(folder, "*/")))
358
+ folders.append(folder)
359
+ exts = [".png", ".jpg", ".jpeg", ".tif", ".tiff", ".flex", ".dax", ".nd2", ".nrrd"]
360
+ l0 = 0
361
+ al = 0
362
+ for folder in folders:
363
+ all_files = glob.glob(folder + "/*")
364
+ al += len(all_files)
365
+ for ext in exts:
366
+ image_names.extend(glob.glob(folder + f"/*{imf}{ext}"))
367
+ image_names.extend(glob.glob(folder + f"/*{imf}{ext.upper()}"))
368
+ l0 += len(image_names)
369
+
370
+ # return error if no files found
371
+ if al == 0:
372
+ raise ValueError("ERROR: no files in --dir folder ")
373
+ elif l0 == 0:
374
+ raise ValueError(
375
+ "ERROR: no images in --dir folder with extensions .png, .jpg, .jpeg, .tif, .tiff, .flex"
376
+ )
377
+
378
+ image_names = natsorted(image_names)
379
+ imn = []
380
+ for im in image_names:
381
+ imfile = os.path.splitext(im)[0]
382
+ igood = all([(len(imfile) > len(mask_filter) and
383
+ imfile[-len(mask_filter):] != mask_filter) or
384
+ len(imfile) <= len(mask_filter) for mask_filter in mask_filters])
385
+ if len(imf) > 0:
386
+ igood &= imfile[-len(imf):] == imf
387
+ if igood:
388
+ imn.append(im)
389
+
390
+ image_names = imn
391
+
392
+ # remove duplicates
393
+ image_names = [*set(image_names)]
394
+ image_names = natsorted(image_names)
395
+
396
+ if len(image_names) == 0:
397
+ raise ValueError(
398
+ "ERROR: no images in --dir folder without _masks or _flows or _cellprob ending")
399
+
400
+ return image_names
401
+
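A sketch of calling `get_image_files` directly (the folder is hypothetical; previously saved `_masks`/`_flows`/`_cellprob` outputs are skipped automatically):

```python
from cellpose import io  # assumed import path

files = io.get_image_files("/data/images", mask_filter="_masks",
                           look_one_level_down=True)
print(len(files), "images found")
```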
402
+ def get_label_files(image_names, mask_filter, imf=None):
403
+ """
404
+ Get the label files corresponding to the given image names and mask filter.
405
+
406
+ Args:
407
+ image_names (list): List of image names.
408
+ mask_filter (str): Mask filter to be applied.
409
+ imf (str, optional): Image file extension. Defaults to None.
410
+
411
+ Returns:
412
+ tuple: A tuple containing the label file names and flow file names (if present).
413
+ """
414
+ nimg = len(image_names)
415
+ label_names0 = [os.path.splitext(image_names[n])[0] for n in range(nimg)]
416
+
417
+ if imf is not None and len(imf) > 0:
418
+ label_names = [label_names0[n][:-len(imf)] for n in range(nimg)]
419
+ else:
420
+ label_names = label_names0
421
+
422
+ # check for flows
423
+ if os.path.exists(label_names0[0] + "_flows.tif"):
424
+ flow_names = [label_names0[n] + "_flows.tif" for n in range(nimg)]
425
+ else:
426
+ flow_names = [label_names[n] + "_flows.tif" for n in range(nimg)]
427
+ if not all([os.path.exists(flow) for flow in flow_names]):
428
+ io_logger.info(
429
+ "not all flows are present, running flow generation for all images")
430
+ flow_names = None
431
+
432
+ # check for masks
433
+ if mask_filter == "_seg.npy":
434
+ label_names = [label_names[n] + mask_filter for n in range(nimg)]
435
+ return label_names, None
436
+
437
+ if os.path.exists(label_names[0] + mask_filter + ".tif"):
438
+ label_names = [label_names[n] + mask_filter + ".tif" for n in range(nimg)]
439
+ elif os.path.exists(label_names[0] + mask_filter + ".tiff"):
440
+ label_names = [label_names[n] + mask_filter + ".tiff" for n in range(nimg)]
441
+ elif os.path.exists(label_names[0] + mask_filter + ".png"):
442
+ label_names = [label_names[n] + mask_filter + ".png" for n in range(nimg)]
443
+ # TODO, allow _seg.npy
444
+ #elif os.path.exists(label_names[0] + "_seg.npy"):
445
+ # io_logger.info("labels found as _seg.npy files, converting to tif")
446
+ else:
447
+ if not flow_names:
448
+ raise ValueError("labels not provided with correct --mask_filter")
449
+ else:
450
+ label_names = None
451
+ if not all([os.path.exists(label) for label in label_names]):
452
+ if not flow_names:
453
+ raise ValueError(
454
+ "labels not provided for all images in train and/or test set")
455
+ else:
456
+ label_names = None
457
+
458
+ return label_names, flow_names
459
+
460
+
461
+ def load_images_labels(tdir, mask_filter="_masks", image_filter=None,
462
+ look_one_level_down=False):
463
+ """
464
+ Loads images and corresponding labels from a directory.
465
+
466
+ Args:
467
+ tdir (str): The directory path.
468
+ mask_filter (str, optional): The filter for mask files. Defaults to "_masks".
469
+ image_filter (str, optional): The filter for image files. Defaults to None.
470
+ look_one_level_down (bool, optional): Whether to look for files one level down. Defaults to False.
471
+
472
+ Returns:
473
+ tuple: A tuple containing a list of images, a list of labels, and a list of image names.
474
+ """
475
+ image_names = get_image_files(tdir, mask_filter, image_filter, look_one_level_down)
476
+ nimg = len(image_names)
477
+
478
+ # training data
479
+ label_names, flow_names = get_label_files(image_names, mask_filter,
480
+ imf=image_filter)
481
+
482
+ images = []
483
+ labels = []
484
+ k = 0
485
+ for n in range(nimg):
486
+ if (os.path.isfile(label_names[n]) or
487
+ (flow_names is not None and os.path.isfile(flow_names[0]))):
488
+ image = imread(image_names[n])
489
+ if label_names is not None:
490
+ label = imread(label_names[n])
491
+ if flow_names is not None:
492
+ flow = imread(flow_names[n])
493
+ if flow.shape[0] < 4:
494
+ label = np.concatenate((label[np.newaxis, :, :], flow), axis=0)
495
+ else:
496
+ label = flow
497
+ images.append(image)
498
+ labels.append(label)
499
+ k += 1
500
+ io_logger.info(f"{k} / {nimg} images in {tdir} folder have labels")
501
+ return images, labels, image_names
502
+
503
+ def load_train_test_data(train_dir, test_dir=None, image_filter=None,
504
+ mask_filter="_masks", look_one_level_down=False):
505
+ """
506
+ Loads training and testing data for a Cellpose model.
507
+
508
+ Args:
509
+ train_dir (str): The directory path containing the training data.
510
+ test_dir (str, optional): The directory path containing the testing data. Defaults to None.
511
+ image_filter (str, optional): The filter for selecting image files. Defaults to None.
512
+ mask_filter (str, optional): The filter for selecting mask files. Defaults to "_masks".
513
+ look_one_level_down (bool, optional): Whether to look for data in subdirectories of train_dir and test_dir. Defaults to False.
514
+
515
+ Returns:
516
+ images, labels, image_names, test_images, test_labels, test_image_names
517
+
518
+ """
519
+ images, labels, image_names = load_images_labels(train_dir, mask_filter,
520
+ image_filter, look_one_level_down)
521
+ # testing data
522
+ test_images, test_labels, test_image_names = None, None, None
523
+ if test_dir is not None:
524
+ test_images, test_labels, test_image_names = load_images_labels(
525
+ test_dir, mask_filter, image_filter, look_one_level_down)
526
+
527
+ return images, labels, image_names, test_images, test_labels, test_image_names
528
+
529
+
530
+ def masks_flows_to_seg(images, masks, flows, file_names,
531
+ channels=None,
532
+ imgs_restore=None, restore_type=None, ratio=1.):
533
+ """Save output of model eval to be loaded in GUI.
534
+
535
+ Can be list output (run on multiple images) or single output (run on single image).
536
+
537
+ Saved to file_names[k]+"_seg.npy".
538
+
539
+ Args:
540
+ images (list): Images input into cellpose.
541
+ masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
542
+ flows (list): Flows output from Cellpose.eval.
543
+ file_names (list, str): Names of files of images.
544
+ diams (float array): Diameters used to run Cellpose. Defaults to 30. TODO: remove this
545
+ channels (list, int, optional): Channels used to run Cellpose. Defaults to None.
546
+
547
+ Returns:
548
+ None
549
+ """
550
+
551
+ if channels is None:
552
+ channels = [0, 0]
553
+
554
+ if isinstance(masks, list):
555
+ if imgs_restore is None:
556
+ imgs_restore = [None] * len(masks)
557
+ if isinstance(file_names, str):
558
+ file_names = [file_names] * len(masks)
559
+ for k, [image, mask, flow,
560
+ # diam,
561
+ file_name, img_restore
562
+ ] in enumerate(zip(images, masks, flows,
563
+ # diams,
564
+ file_names,
565
+ imgs_restore)):
566
+ channels_img = channels
567
+ if channels_img is not None and len(channels) > 2:
568
+ channels_img = channels[k]
569
+ masks_flows_to_seg(image, mask, flow, file_name,
570
+ # diams=diam,
571
+ channels=channels_img, imgs_restore=img_restore,
572
+ restore_type=restore_type, ratio=ratio)
573
+ return
574
+
575
+ if len(channels) == 1:
576
+ channels = channels[0]
577
+
578
+ flowi = []
579
+ if flows[0].ndim == 3:
580
+ Ly, Lx = masks.shape[-2:]
581
+ flowi.append(
582
+ cv2.resize(flows[0], (Lx, Ly), interpolation=cv2.INTER_NEAREST)[np.newaxis,
583
+ ...])
584
+ else:
585
+ flowi.append(flows[0])
586
+
587
+ if flows[0].ndim == 3:
588
+ cellprob = (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(
589
+ np.uint8)
590
+ cellprob = cv2.resize(cellprob, (Lx, Ly), interpolation=cv2.INTER_NEAREST)
591
+ flowi.append(cellprob[np.newaxis, ...])
592
+ flowi.append(np.zeros(flows[0].shape, dtype=np.uint8))
593
+ flowi[-1] = flowi[-1][np.newaxis, ...]
594
+ else:
595
+ flowi.append(
596
+ (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(np.uint8))
597
+ flowi.append((flows[1][0] / 10 * 127 + 127).astype(np.uint8))
598
+ if len(flows) > 2:
599
+ if len(flows) > 3:
600
+ flowi.append(flows[3])
601
+ else:
602
+ flowi.append([])
603
+ flowi.append(np.concatenate((flows[1], flows[2][np.newaxis, ...]), axis=0))
604
+ outlines = masks * utils.masks_to_outlines(masks)
605
+ base = os.path.splitext(file_names)[0]
606
+
607
+ dat = {
608
+ "outlines":
609
+ outlines.astype(np.uint16) if outlines.max() < 2**16 -
610
+ 1 else outlines.astype(np.uint32),
611
+ "masks":
612
+ masks.astype(np.uint16) if outlines.max() < 2**16 -
613
+ 1 else masks.astype(np.uint32),
614
+ "chan_choose":
615
+ channels,
616
+ "ismanual":
617
+ np.zeros(masks.max(), bool),
618
+ "filename":
619
+ file_names,
620
+ "flows":
621
+ flowi,
622
+ "diameter":
623
+ np.nan
624
+ }
625
+ if restore_type is not None and imgs_restore is not None:
626
+ dat["restore"] = restore_type
627
+ dat["ratio"] = ratio
628
+ dat["img_restore"] = imgs_restore
629
+
630
+ np.save(base + "_seg.npy", dat)
631
+
632
+ def save_to_png(images, masks, flows, file_names):
633
+ """ deprecated (runs io.save_masks with png=True)
634
+
635
+ does not work for 3D images
636
+
637
+ """
638
+ save_masks(images, masks, flows, file_names, png=True)
639
+
640
+
641
+ def save_rois(masks, file_name, multiprocessing=None):
642
+ """ save masks to .roi files in .zip archive for ImageJ/Fiji
643
+
644
+ Args:
645
+ masks (np.ndarray): masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels
646
+ file_name (str): name to save the .zip file to
647
+
648
+ Returns:
649
+ None
650
+ """
651
+ outlines = utils.outlines_list(masks, multiprocessing=multiprocessing)
652
+ nonempty_outlines = [outline for outline in outlines if len(outline)!=0]
653
+ if len(outlines)!=len(nonempty_outlines):
654
+ print(f"empty outlines found, saving {len(nonempty_outlines)} ImageJ ROIs to .zip archive.")
655
+ rois = [ImagejRoi.frompoints(outline) for outline in nonempty_outlines]
656
+ file_name = os.path.splitext(file_name)[0] + '_rois.zip'
657
+
658
+
659
+ # Delete file if it exists; the roifile lib appends to existing zip files.
660
+ # If the user removed a mask it will still be in the zip file
661
+ if os.path.exists(file_name):
662
+ os.remove(file_name)
663
+
664
+ roiwrite(file_name, rois)
665
+
666
+
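A minimal sketch of calling `save_rois` on a synthetic 2D label image (the `_rois.zip` suffix replaces the file extension automatically):

```python
import numpy as np
from cellpose.io import save_rois  # assumed import path

masks = np.zeros((64, 64), np.uint16)
masks[5:20, 5:20] = 1                  # two hypothetical cells
masks[30:50, 30:50] = 2
save_rois(masks, "image.png")          # writes image_rois.zip for ImageJ/Fiji
```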
667
+ def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[0, 0],
668
+ suffix="_cp_masks", save_flows=False, save_outlines=False, dir_above=False,
669
+ in_folders=False, savedir=None, save_txt=False, save_mpl=False):
670
+ """ Save masks + nicely plotted segmentation image to png and/or tiff.
671
+
672
+ Can save masks, flows to different directories, if in_folders is True.
673
+
674
+ If png, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.png".
675
+
676
+ If tif, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.tif".
677
+
678
+ If png and matplotlib installed, full segmentation figure is saved to file_names[k]+"_cp.png".
679
+
680
+ Only tif option works for 3D data, and only tif option works for empty masks.
681
+
682
+ Args:
683
+ images (list): Images input into cellpose.
684
+ masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
685
+ flows (list): Flows output from Cellpose.eval.
686
+ file_names (list, str): Names of files of images.
687
+ png (bool, optional): Save masks to PNG. Defaults to True.
688
+ tif (bool, optional): Save masks to TIF. Defaults to False.
689
+ channels (list, int, optional): Channels used to run Cellpose. Defaults to [0,0].
690
+ suffix (str, optional): Add name to saved masks. Defaults to "_cp_masks".
691
+ save_flows (bool, optional): Save flows output from Cellpose.eval. Defaults to False.
692
+ save_outlines (bool, optional): Save outlines of masks. Defaults to False.
693
+ dir_above (bool, optional): Save masks/flows in directory above. Defaults to False.
694
+ in_folders (bool, optional): Save masks/flows in separate folders. Defaults to False.
695
+ savedir (str, optional): Absolute path where images will be saved. If None, saves to image directory. Defaults to None.
696
+ save_txt (bool, optional): Save masks as list of outlines for ImageJ. Defaults to False.
697
+ save_mpl (bool, optional): If True, saves a matplotlib figure of the original image/segmentation/flows. Does not work for 3D.
698
+ This takes a long time for large images. Defaults to False.
699
+
700
+ Returns:
701
+ None
702
+ """
703
+
704
+ if isinstance(masks, list):
705
+ for image, mask, flow, file_name in zip(images, masks, flows, file_names):
706
+ save_masks(image, mask, flow, file_name, png=png, tif=tif, suffix=suffix,
707
+ dir_above=dir_above, save_flows=save_flows,
708
+ save_outlines=save_outlines, savedir=savedir, save_txt=save_txt,
709
+ in_folders=in_folders, save_mpl=save_mpl)
710
+ return
711
+
712
+ if masks.ndim > 2 and not tif:
713
+ raise ValueError("cannot save 3D outputs as PNG, use tif option instead")
714
+
715
+ if masks.max() == 0:
716
+ io_logger.warning("no masks found, will not save PNG or outlines")
717
+ if not tif:
718
+ return
719
+ else:
720
+ png = False
721
+ save_outlines = False
722
+ save_flows = False
723
+ save_txt = False
724
+
725
+ if savedir is None:
726
+ if dir_above:
727
+ savedir = Path(file_names).parent.parent.absolute(
728
+ ) #go up a level to save in its own folder
729
+ else:
730
+ savedir = Path(file_names).parent.absolute()
731
+
732
+ check_dir(savedir)
733
+
734
+ basename = os.path.splitext(os.path.basename(file_names))[0]
735
+ if in_folders:
736
+ maskdir = os.path.join(savedir, "masks")
737
+ outlinedir = os.path.join(savedir, "outlines")
738
+ txtdir = os.path.join(savedir, "txt_outlines")
739
+ flowdir = os.path.join(savedir, "flows")
740
+ else:
741
+ maskdir = savedir
742
+ outlinedir = savedir
743
+ txtdir = savedir
744
+ flowdir = savedir
745
+
746
+ check_dir(maskdir)
747
+
748
+ exts = []
749
+ if masks.ndim > 2:
750
+ png = False
751
+ tif = True
752
+ if png:
753
+ if masks.max() < 2**16:
754
+ masks = masks.astype(np.uint16)
755
+ exts.append(".png")
756
+ else:
757
+ png = False
758
+ tif = True
759
+ io_logger.warning(
760
+ "found more than 65535 masks in each image, cannot save PNG, saving as TIF"
761
+ )
762
+ if tif:
763
+ exts.append(".tif")
764
+
765
+ # save masks
766
+ with warnings.catch_warnings():
767
+ warnings.simplefilter("ignore")
768
+ for ext in exts:
769
+ imsave(os.path.join(maskdir, basename + suffix + ext), masks)
770
+
771
+ if save_mpl and png and MATPLOTLIB and min(images.shape) <= 3:
772
+ # Make and save original/segmentation/flows image
773
+
774
+ img = images.copy()
775
+ if img.ndim < 3:
776
+ img = img[:, :, np.newaxis]
777
+ elif img.shape[0] < 8:
778
+ img = np.transpose(img, (1, 2, 0))
779
+
780
+ fig = plt.figure(figsize=(12, 3))
781
+ plot.show_segmentation(fig, img, masks, flows[0])
782
+ fig.savefig(os.path.join(savedir, basename + "_cp_output" + suffix + ".png"),
783
+ dpi=300)
784
+ plt.close(fig)
785
+
786
+ # ImageJ txt outline files
787
+ if masks.ndim < 3 and save_txt:
788
+ check_dir(txtdir)
789
+ outlines = utils.outlines_list(masks)
790
+ outlines_to_text(os.path.join(txtdir, basename), outlines)
791
+
792
+ # RGB outline images
793
+ if masks.ndim < 3 and save_outlines:
794
+ check_dir(outlinedir)
795
+ outlines = utils.masks_to_outlines(masks)
796
+ outX, outY = np.nonzero(outlines)
797
+ img0 = transforms.normalize99(images)
798
+ if img0.shape[0] < 4:
799
+ img0 = np.transpose(img0, (1, 2, 0))
800
+ if img0.shape[-1] < 3 or img0.ndim < 3:
801
+ img0 = plot.image_to_rgb(img0, channels=channels)
802
+ else:
803
+ if img0.max() <= 50.0:
804
+ img0 = np.uint8(np.clip(img0, 0, 1) * 255)
805
+ imgout = img0.copy()
806
+ imgout[outX, outY] = np.array([255, 0, 0]) #pure red
807
+ imsave(os.path.join(outlinedir, basename + "_outlines" + suffix + ".png"),
808
+ imgout)
809
+
810
+ # save RGB flow picture
811
+ if masks.ndim < 3 and save_flows:
812
+ check_dir(flowdir)
813
+ imsave(os.path.join(flowdir, basename + "_flows" + suffix + ".tif"),
814
+ (flows[0] * (2**16 - 1)).astype(np.uint16))
815
+ #save full flow data
816
+ imsave(os.path.join(flowdir, basename + '_dP' + suffix + '.tif'), flows[1])
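
A minimal end-to-end sketch of `save_masks` (not part of the diff): it assumes the vendored package is importable as `cellpose` and that the example image shipped in this repo exists at the path below.

    from cellpose import models, io

    img = io.imread("example_imgs/1977_Well_F-5_Field_1.png")
    model = models.CellposeModel(gpu=False)               # loads the default "cpsam" weights if present
    masks, flows, styles = model.eval(img)
    io.save_masks(img, masks, flows, "example_imgs/1977_Well_F-5_Field_1.png",
                  png=True, save_outlines=True, save_txt=True)  # writes *_cp_masks.png etc.
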
models/seg_post_model/cellpose/metrics.py ADDED
@@ -0,0 +1,205 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import numpy as np
5
+ from . import utils
6
+ from scipy.optimize import linear_sum_assignment
7
+ from scipy.ndimage import convolve
8
+ from scipy.sparse import csr_matrix
9
+
10
+
11
+ def mask_ious(masks_true, masks_pred):
12
+ """Return best-matched masks."""
13
+ iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
14
+ n_min = min(iou.shape[0], iou.shape[1])
15
+ costs = -(iou >= 0.5).astype(float) - iou / (2 * n_min)
16
+ true_ind, pred_ind = linear_sum_assignment(costs)
17
+ iout = np.zeros(masks_true.max())
18
+ iout[true_ind] = iou[true_ind, pred_ind]
19
+ preds = np.zeros(masks_true.max(), "int")
20
+ preds[true_ind] = pred_ind + 1
21
+ return iout, preds
22
+
23
+
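
To make `mask_ious` concrete, a toy check (assuming the functions in this file are in scope): two 2x2 label images whose single masks share one pixel should match with IoU 1/3.

    import numpy as np

    mt = np.array([[1, 1], [0, 0]])
    mp = np.array([[1, 0], [1, 0]])
    iout, preds = mask_ious(mt, mp)
    print(iout, preds)  # [0.333...] [1] -> true mask 1 matched to pred mask 1
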
24
+ def boundary_scores(masks_true, masks_pred, scales):
25
+ """
26
+ Calculate boundary precision, recall, and F-score.
27
+
28
+ Args:
29
+ masks_true (list): List of true masks.
30
+ masks_pred (list): List of predicted masks.
31
+ scales (list): List of scales.
32
+
33
+ Returns:
34
+ tuple: A tuple containing precision, recall, and F-score arrays.
35
+ """
36
+ diams = [utils.diameters(lbl)[0] for lbl in masks_true]
37
+ precision = np.zeros((len(scales), len(masks_true)))
38
+ recall = np.zeros((len(scales), len(masks_true)))
39
+ fscore = np.zeros((len(scales), len(masks_true)))
40
+ for j, scale in enumerate(scales):
41
+ for n in range(len(masks_true)):
42
+ diam = max(1, scale * diams[n])
43
+ rs, ys, xs = utils.circleMask([int(np.ceil(diam)), int(np.ceil(diam))])
44
+ filt = (rs <= diam).astype(np.float32)
45
+ otrue = utils.masks_to_outlines(masks_true[n])
46
+ otrue = convolve(otrue, filt)
47
+ opred = utils.masks_to_outlines(masks_pred[n])
48
+ opred = convolve(opred, filt)
49
+ tp = np.logical_and(otrue == 1, opred == 1).sum()
50
+ fp = np.logical_and(otrue == 0, opred == 1).sum()
51
+ fn = np.logical_and(otrue == 1, opred == 0).sum()
52
+ precision[j, n] = tp / (tp + fp)
53
+ recall[j, n] = tp / (tp + fn)
54
+ fscore[j] = 2 * precision[j] * recall[j] / (precision[j] + recall[j])
55
+ return precision, recall, fscore
56
+
57
+
58
+ def aggregated_jaccard_index(masks_true, masks_pred):
59
+ """
60
+ AJI = intersection of all matched masks / union of all masks
61
+
62
+ Args:
63
+ masks_true (list of np.ndarrays (int) or np.ndarray (int)):
64
+ where 0=NO masks; 1,2... are mask labels
65
+ masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
66
+ np.ndarray (int) where 0=NO masks; 1,2... are mask labels
67
+
68
+ Returns:
69
+ aji (float): aggregated jaccard index for each set of masks
70
+ """
71
+ aji = np.zeros(len(masks_true))
72
+ for n in range(len(masks_true)):
73
+ iout, preds = mask_ious(masks_true[n], masks_pred[n])
74
+ inds = np.arange(0, masks_true[n].max(), 1, int)
75
+ overlap = _label_overlap(masks_true[n], masks_pred[n])
76
+ union = np.logical_or(masks_true[n] > 0, masks_pred[n] > 0).sum()
77
+ overlap = overlap[inds[preds > 0] + 1, preds[preds > 0].astype(int)]
78
+ aji[n] = overlap.sum() / union
79
+ return aji
80
+
81
+
82
+ def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
83
+ """
84
+ Average precision estimation: AP = TP / (TP + FP + FN)
85
+
86
+ This function is based heavily on the *fast* stardist matching functions
87
+ (https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py)
88
+
89
+ Args:
90
+ masks_true (list of np.ndarrays (int) or np.ndarray (int)):
91
+ where 0=NO masks; 1,2... are mask labels
92
+ masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
93
+ np.ndarray (int) where 0=NO masks; 1,2... are mask labels
94
+
95
+ Returns:
96
+ ap (array [len(masks_true) x len(threshold)]):
97
+ average precision at thresholds
98
+ tp (array [len(masks_true) x len(threshold)]):
99
+ number of true positives at thresholds
100
+ fp (array [len(masks_true) x len(threshold)]):
101
+ number of false positives at thresholds
102
+ fn (array [len(masks_true) x len(threshold)]):
103
+ number of false negatives at thresholds
104
+ """
105
+ not_list = False
106
+ if not isinstance(masks_true, list):
107
+ masks_true = [masks_true]
108
+ masks_pred = [masks_pred]
109
+ not_list = True
110
+ if not isinstance(threshold, list) and not isinstance(threshold, np.ndarray):
111
+ threshold = [threshold]
112
+
113
+ if len(masks_true) != len(masks_pred):
114
+ raise ValueError(
115
+ "metrics.average_precision requires len(masks_true)==len(masks_pred)")
116
+
117
+ ap = np.zeros((len(masks_true), len(threshold)), np.float32)
118
+ tp = np.zeros((len(masks_true), len(threshold)), np.float32)
119
+ fp = np.zeros((len(masks_true), len(threshold)), np.float32)
120
+ fn = np.zeros((len(masks_true), len(threshold)), np.float32)
121
+ n_true = np.array([len(np.unique(mt)) - 1 for mt in masks_true])
122
+ n_pred = np.array([len(np.unique(mp)) - 1 for mp in masks_pred])
123
+
124
+ for n in range(len(masks_true)):
125
+ #_,mt = np.reshape(np.unique(masks_true[n], return_index=True), masks_pred[n].shape)
126
+ if n_pred[n] > 0:
127
+ iou = _intersection_over_union(masks_true[n], masks_pred[n])[1:, 1:]
128
+ for k, th in enumerate(threshold):
129
+ tp[n, k] = _true_positive(iou, th)
130
+ fp[n] = n_pred[n] - tp[n]
131
+ fn[n] = n_true[n] - tp[n]
132
+ ap[n] = tp[n] / (tp[n] + fp[n] + fn[n])
133
+
134
+ if not_list:
135
+ ap, tp, fp, fn = ap[0], tp[0], fp[0], fn[0]
136
+ return ap, tp, fp, fn
137
+
138
+
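
As a sanity check of the formula AP = TP / (TP + FP + FN), a perfect prediction should score 1.0 at every threshold (toy arrays, assuming this module's functions are in scope):

    import numpy as np

    gt = np.zeros((4, 4), int)
    gt[:2, :2] = 1
    gt[2:, 2:] = 2
    ap, tp, fp, fn = average_precision(gt, gt.copy(), threshold=[0.5, 0.75, 0.9])
    print(ap)  # [1. 1. 1.] -- both masks matched at every threshold
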
139
+ def _intersection_over_union(masks_true, masks_pred):
140
+ """Calculate the intersection over union of all mask pairs.
141
+
142
+ Parameters:
143
+ masks_true (np.ndarray, int): Ground truth masks, where 0=NO masks; 1,2... are mask labels.
144
+ masks_pred (np.ndarray, int): Predicted masks, where 0=NO masks; 1,2... are mask labels.
145
+
146
+ Returns:
147
+ iou (np.ndarray, float): Matrix of IOU pairs of size [x.max()+1, y.max()+1].
148
+
149
+ How it works:
150
+ The overlap matrix is a lookup table of the area of intersection
151
+ between each set of labels (true and predicted). The true labels
152
+ are taken to be along axis 0, and the predicted labels are taken
153
+ to be along axis 1. The sum of the overlaps along axis 0 is thus
154
+ an array giving the total overlap of the true labels with each of
155
+ the predicted labels, and likewise the sum over axis 1 is the
156
+ total overlap of the predicted labels with each of the true labels.
157
+ Because the label 0 (background) is included, this sum is guaranteed
158
+ to reconstruct the total area of each label. Adding this row and
159
+ column vectors gives a 2D array with the areas of every label pair
160
+ added together. This is equivalent to the union of the label areas
161
+ except for the duplicated overlap area, so the overlap matrix is
162
+ subtracted to find the union matrix.
163
+ """
164
+ if masks_true.size != masks_pred.size:
165
+ raise ValueError(f"masks_true.size {masks_true.shape} != masks_pred.size {masks_pred.shape}")
166
+ overlap = csr_matrix((np.ones((masks_true.size,), "int"),
167
+ (masks_true.flatten(), masks_pred.flatten())),
168
+ shape=(masks_true.max()+1, masks_pred.max()+1))
169
+ overlap = overlap.toarray()
170
+ n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
171
+ n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
172
+ iou = overlap / (n_pixels_pred + n_pixels_true - overlap)
173
+ iou[np.isnan(iou)] = 0.0
174
+ return iou
175
+
176
+
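
The "How it works" note above, made concrete with toy arrays (an illustrative call, not pipeline code): `overlap[i, j]` counts pixels where true label i coincides with predicted label j, and the union follows from the row/column sums.

    import numpy as np

    mt = np.array([[0, 1], [1, 1]])   # true mask 1 covers 3 pixels
    mp = np.array([[0, 0], [1, 1]])   # pred mask 1 covers 2 pixels
    iou = _intersection_over_union(mt, mp)
    print(iou[1, 1])  # 2 shared pixels / (3 + 2 - 2) union = 0.666...
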
177
+ def _true_positive(iou, th):
178
+ """Calculate the true positive at threshold th.
179
+
180
+ Args:
181
+ iou (float, np.ndarray): Array of IOU pairs.
182
+ th (float): Threshold on IOU for positive label.
183
+
184
+ Returns:
185
+ tp (float): Number of true positives at threshold.
186
+
187
+ How it works:
188
+ (1) Find minimum number of masks.
189
+ (2) Define cost matrix; for a given threshold, each element is negative
190
+ the higher the IoU is (perfect IoU is 1, worst is 0). The second term
191
+ gets more negative with higher IoU, but less negative with greater
192
+ n_min (but that's a constant...).
193
+ (3) Solve the linear sum assignment problem. The costs array defines the cost
194
+ of matching a true label with a predicted label, so the problem is to
195
+ find the set of pairings that minimizes this cost. The scipy.optimize
196
+ function gives the ordered lists of corresponding true and predicted labels.
197
+ (4) Extract the IoUs from these pairings and then threshold to get a boolean array
198
+ whose sum is the number of true positives that is returned.
199
+ """
200
+ n_min = min(iou.shape[0], iou.shape[1])
201
+ costs = -(iou >= th).astype(float) - iou / (2 * n_min)
202
+ true_ind, pred_ind = linear_sum_assignment(costs)
203
+ match_ok = iou[true_ind, pred_ind] >= th
204
+ tp = match_ok.sum()
205
+ return tp
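
A worked instance of the matching steps above (toy IoU matrix, assumed values): one pair clears the threshold and one does not, so exactly one true positive is counted at th=0.5.

    import numpy as np

    iou = np.array([[0.8, 0.1],
                    [0.2, 0.3]])
    print(_true_positive(iou, th=0.5))  # 1 -- only the 0.8 pairing survives
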
models/seg_post_model/cellpose/models.py ADDED
@@ -0,0 +1,524 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+
5
+ import os, time
6
+ from pathlib import Path
7
+ import numpy as np
8
+ from tqdm import trange
9
+ import torch
10
+ from scipy.ndimage import gaussian_filter
11
+ import gc
12
+ import cv2
13
+
14
+ import logging
15
+
16
+ models_logger = logging.getLogger(__name__)
17
+
18
+ from . import transforms, dynamics, utils, plot
19
+ from .vit_sam import Transformer
20
+ from .core import assign_device, run_net, run_3D
21
+
22
+ _CPSAM_MODEL_URL = "https://huggingface.co/mouseland/cellpose-sam/resolve/main/cpsam"
23
+ _MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
24
+ # _MODEL_DIR_DEFAULT = Path.home().joinpath(".cellpose", "models")
25
+ _MODEL_DIR_DEFAULT = Path("/media/data1/huix/seg/cellpose_models")
26
+ MODEL_DIR = Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT
27
+
28
+ MODEL_NAMES = ["cpsam"]
29
+
30
+ MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt"))
31
+
32
+ normalize_default = {
33
+ "lowhigh": None,
34
+ "percentile": None,
35
+ "normalize": True,
36
+ "norm3D": True,
37
+ "sharpen_radius": 0,
38
+ "smooth_radius": 0,
39
+ "tile_norm_blocksize": 0,
40
+ "tile_norm_smooth3D": 1,
41
+ "invert": False
42
+ }
43
+
44
+
45
+ # def model_path(model_type, model_index=0):
46
+ # return cache_CPSAM_model_path()
47
+
48
+
49
+ # def cache_CPSAM_model_path():
50
+ # MODEL_DIR.mkdir(parents=True, exist_ok=True)
51
+ # cached_file = os.fspath(MODEL_DIR.joinpath('cpsam'))
52
+ # if not os.path.exists(cached_file):
53
+ # models_logger.info('Downloading: "{}" to {}\n'.format(_CPSAM_MODEL_URL, cached_file))
54
+ # utils.download_url_to_file(_CPSAM_MODEL_URL, cached_file, progress=True)
55
+ # return cached_file
56
+
57
+
58
+ def get_user_models():
59
+ model_strings = []
60
+ if os.path.exists(MODEL_LIST_PATH):
61
+ with open(MODEL_LIST_PATH, "r") as textfile:
62
+ lines = [line.rstrip() for line in textfile]
63
+ if len(lines) > 0:
64
+ model_strings.extend(lines)
65
+ return model_strings
66
+
67
+
68
+ class CellposeModel():
69
+ """
70
+ Class representing a Cellpose model.
71
+
72
+ Attributes:
73
+ diam_mean (float): Mean "diameter" value for the model.
74
+ builtin (bool): Whether the model is a built-in model or not.
75
+ device (torch device): Device used for model running / training.
76
+ nclasses (int): Number of classes in the model.
77
+ nbase (list): List of base values for the model.
78
+ net (CPnet): Cellpose network.
79
+ pretrained_model (str): Path to pretrained cellpose model.
80
+ pretrained_model_ortho (str): Path or model_name for pretrained cellpose model for ortho views in 3D.
81
+ backbone (str): Type of network ("default" is the standard res-unet, "transformer" for the segformer).
82
+
83
+ Methods:
84
+ __init__(self, gpu=False, pretrained_model=False, model_type=None, diam_mean=30., device=None):
85
+ Initialize the CellposeModel.
86
+
87
+ eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, z_axis=None, normalize=True, invert=False, rescale=None, diameter=None, flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0, min_size=15, niter=None, augment=False, tile_overlap=0.1, bsize=224, interp=True, compute_masks=True, progress=None):
88
+ Segment list of images x, or 4D array - Z x C x Y x X.
89
+
90
+ """
91
+
92
+ def __init__(self, gpu=False, pretrained_model="", model_type=None,
93
+ diam_mean=None, device=None, nchan=None, use_bfloat16=True, vit_checkpoint=None):
94
+ """
95
+ Initialize the CellposeModel.
96
+
97
+ Parameters:
98
+ gpu (bool, optional): Whether or not to save model to GPU, will check if GPU available.
99
+ pretrained_model (str or list of strings, optional): Full path to pretrained cellpose model(s), if None or False, no model loaded.
100
+ model_type (str, optional): Any model that is available in the GUI, use name in GUI e.g. "livecell" (can be user-trained or model zoo).
101
+ diam_mean (float, optional): Mean "diameter", 30. is built-in value for "cyto" model; 17. is built-in value for "nuclei" model; if saved in custom model file (cellpose>=2.0) then it will be loaded automatically and overwrite this value.
102
+ device (torch device, optional): Device used for model running / training (torch.device("cuda") or torch.device("cpu")), overrides gpu input, recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
103
+ use_bfloat16 (bool, optional): Use 16bit float precision instead of 32bit for model weights. Default to 16bit (True).
104
+ """
105
+ # if diam_mean is not None:
106
+ # models_logger.warning(
107
+ # "diam_mean argument are not used in v4.0.1+. Ignoring this argument..."
108
+ # )
109
+ # if model_type is not None:
110
+ # models_logger.warning(
111
+ # "model_type argument is not used in v4.0.1+. Ignoring this argument..."
112
+ # )
113
+ # if nchan is not None:
114
+ # models_logger.warning("nchan argument is deprecated in v4.0.1+. Ignoring this argument")
115
+
116
+ ### assign model device
117
+ self.device = assign_device(gpu=gpu)[0] if device is None else device
118
+ if torch.cuda.is_available():
119
+ device_gpu = self.device.type == "cuda"
120
+ elif torch.backends.mps.is_available():
121
+ device_gpu = self.device.type == "mps"
122
+ else:
123
+ device_gpu = False
124
+ self.gpu = device_gpu
125
+
126
+ if pretrained_model is None:
127
+ # raise ValueError("Must specify a pretrained model, training from scratch is not implemented")
128
+ pretrained_model = ""
129
+
130
+ ### create neural network
131
+ if pretrained_model and not os.path.exists(pretrained_model):
132
+ # check if pretrained model is in the models directory
133
+ model_strings = get_user_models()
134
+ all_models = MODEL_NAMES.copy()
135
+ all_models.extend(model_strings)
136
+ if pretrained_model in all_models:
137
+ pretrained_model = os.path.join(MODEL_DIR, pretrained_model)
138
+ else:
139
+ pretrained_model = os.path.join(MODEL_DIR, "cpsam")
140
+ models_logger.warning(
141
+ f"pretrained model {pretrained_model} not found, using default model"
142
+ )
143
+
144
+ self.pretrained_model = pretrained_model
145
+ dtype = torch.bfloat16 if use_bfloat16 else torch.float32
146
+ self.net = Transformer(dtype=dtype, checkpoint=vit_checkpoint).to(self.device)
147
+
148
+ if os.path.exists(self.pretrained_model):
149
+ models_logger.info(f">>>> loading model {self.pretrained_model}")
150
+ self.net.load_model(self.pretrained_model, device=self.device)
151
+ # else:
152
+ # try:
153
+ # if os.path.split(self.pretrained_model)[-1] != 'cpsam':
154
+ # raise FileNotFoundError('model file not recognized')
155
+ # cache_CPSAM_model_path()
156
+ # self.net.load_model(self.pretrained_model, device=self.device)
157
+ # except:
158
+ # print("ViT not initialized")
159
+
160
+
161
+ def eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None,
162
+ z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
163
+ flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None,
164
+ flow3D_smooth=0, stitch_threshold=0.0,
165
+ min_size=15, max_size_fraction=0.4, niter=None,
166
+ augment=False, tile_overlap=0.1, bsize=256,
167
+ compute_masks=True, progress=None):
168
+
169
+
170
+ # if rescale is not None:
171
+ # models_logger.warning("rescaling deprecated in v4.0.1+")
172
+ # if channels is not None:
173
+ # models_logger.warning("channels deprecated in v4.0.1+. If data contain more than 3 channels, only the first 3 channels will be used")
174
+
175
+ if isinstance(x, list) or x.squeeze().ndim == 5:
176
+ self.timing = []
177
+ masks, styles, flows = [], [], []
178
+ tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
179
+ nimg = len(x)
180
+ iterator = trange(nimg, file=tqdm_out,
181
+ mininterval=30) if nimg > 1 else range(nimg)
182
+ for i in iterator:
183
+ tic = time.time()
184
+ maski, flowi, stylei = self.eval(
185
+ x[i],
186
+ feat=None if feat is None else feat[i],
187
+ batch_size=batch_size,
188
+ channel_axis=channel_axis,
189
+ z_axis=z_axis,
190
+ normalize=normalize,
191
+ invert=invert,
192
+ diameter=diameter[i] if isinstance(diameter, list) or
193
+ isinstance(diameter, np.ndarray) else diameter,
194
+ do_3D=do_3D,
195
+ anisotropy=anisotropy,
196
+ augment=augment,
197
+ tile_overlap=tile_overlap,
198
+ bsize=bsize,
199
+ resample=resample,
200
+ flow_threshold=flow_threshold,
201
+ cellprob_threshold=cellprob_threshold,
202
+ compute_masks=compute_masks,
203
+ min_size=min_size,
204
+ max_size_fraction=max_size_fraction,
205
+ stitch_threshold=stitch_threshold,
206
+ flow3D_smooth=flow3D_smooth,
207
+ progress=progress,
208
+ niter=niter)
209
+ masks.append(maski)
210
+ flows.append(flowi)
211
+ styles.append(stylei)
212
+ self.timing.append(time.time() - tic)
213
+ return masks, flows, styles
214
+
215
+ ############# actual eval code ############
216
+ # reshape image
217
+ x = transforms.convert_image(x, channel_axis=channel_axis,
218
+ z_axis=z_axis,
219
+ do_3D=(do_3D or stitch_threshold > 0))
220
+
221
+ # Add batch dimension if not present
222
+ if x.ndim < 4:
223
+ x = x[np.newaxis, ...]
224
+ if feat is not None:
225
+ if feat.ndim < 4:
226
+ feat = feat[np.newaxis, ...]
227
+ nimg = x.shape[0]
228
+
229
+ image_scaling = None
230
+ Ly_0 = x.shape[1]
231
+ Lx_0 = x.shape[2]
232
+ Lz_0 = None
233
+ if do_3D or stitch_threshold > 0:
234
+ Lz_0 = x.shape[0]
235
+ if diameter is not None:
236
+ image_scaling = 30. / diameter
237
+ x = transforms.resize_image(x,
238
+ Ly=int(x.shape[1] * image_scaling),
239
+ Lx=int(x.shape[2] * image_scaling))
240
+ if feat is not None:
241
+ feat = transforms.resize_image(feat,
242
+ Ly=int(feat.shape[1] * image_scaling),
243
+ Lx=int(feat.shape[2] * image_scaling))
244
+
245
+
246
+ # normalize image
247
+ normalize_params = normalize_default
248
+ if isinstance(normalize, dict):
249
+ normalize_params = {**normalize_params, **normalize}
250
+ elif not isinstance(normalize, bool):
251
+ raise ValueError("normalize parameter must be a bool or a dict")
252
+ else:
253
+ normalize_params["normalize"] = normalize
254
+ normalize_params["invert"] = invert
255
+
256
+ # pre-normalize if 3D stack for stitching or do_3D
257
+ do_normalization = True if normalize_params["normalize"] else False
258
+ if nimg > 1 and do_normalization and (stitch_threshold or do_3D):
259
+ normalize_params["norm3D"] = True if do_3D else normalize_params["norm3D"]
260
+ x = transforms.normalize_img(x, **normalize_params)
261
+ do_normalization = False # do not normalize again
262
+ else:
263
+ if normalize_params["norm3D"] and nimg > 1 and do_normalization:
264
+ models_logger.warning(
265
+ "normalize_params['norm3D'] is True but do_3D is False and stitch_threshold=0, so setting to False"
266
+ )
267
+ normalize_params["norm3D"] = False
268
+ if do_normalization:
269
+ x = transforms.normalize_img(x, **normalize_params)
270
+
271
+ if feat is not None:
272
+ if feat.shape[-1] > feat.shape[1]:
273
+ # transpose feat to have channels last
274
+ feat = np.moveaxis(feat, 1, -1)
275
+
276
+ # adjust the anisotropy when diameter is specified and images are resized:
277
+ if isinstance(anisotropy, (float, int)) and image_scaling:
278
+ anisotropy = image_scaling * anisotropy
279
+
280
+ dP, cellprob, styles = self._run_net(
281
+ x,
282
+ feat=feat,
283
+ augment=augment,
284
+ batch_size=batch_size,
285
+ tile_overlap=tile_overlap,
286
+ bsize=bsize,
287
+ do_3D=do_3D,
288
+ anisotropy=anisotropy)
289
+
290
+ if do_3D:
291
+ if flow3D_smooth > 0:
292
+ models_logger.info(f"smoothing flows with sigma={flow3D_smooth}")
293
+ dP = gaussian_filter(dP, (0, flow3D_smooth, flow3D_smooth, flow3D_smooth))
294
+ torch.cuda.empty_cache()
295
+ gc.collect()
296
+
297
+ if resample:
298
+ # upsample flows before computing them:
299
+ dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
300
+ cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
301
+
302
+
303
+ if compute_masks:
304
+ niter0 = 200
305
+ niter = niter0 if niter is None or niter == 0 else niter
306
+ masks = self._compute_masks(x.shape, dP, cellprob, flow_threshold=flow_threshold,
307
+ cellprob_threshold=cellprob_threshold, min_size=min_size,
308
+ max_size_fraction=max_size_fraction, niter=niter,
309
+ stitch_threshold=stitch_threshold, do_3D=do_3D)
310
+ else:
311
+ masks = np.zeros(0) #pass back zeros if not compute_masks
312
+
313
+ masks, dP, cellprob = masks.squeeze(), dP.squeeze(), cellprob.squeeze()
314
+
315
+ # undo resizing:
316
+ if image_scaling is not None or anisotropy is not None:
317
+
318
+ dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0) # works for 2 or 3D:
319
+ cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)
320
+
321
+ if do_3D:
322
+ if compute_masks:
323
+ # Rescale xy then xz:
324
+ masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
325
+ masks = masks.transpose(1, 0, 2)
326
+ masks = transforms.resize_image(masks, Ly=Lz_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
327
+ masks = masks.transpose(1, 0, 2)
328
+
329
+ else:
330
+ # 2D or 3D stitching case:
331
+ if compute_masks:
332
+ masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
333
+
334
+ return masks, [plot.dx_to_circ(dP), dP, cellprob], styles
335
+
336
+
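
A minimal usage sketch for `eval` on a single 2D image (the weights and image here are assumptions; `pretrained_model="cpsam"` resolves inside MODEL_DIR):

    import numpy as np

    model = CellposeModel(gpu=True, pretrained_model="cpsam")
    img = np.random.rand(512, 512, 3).astype("float32")    # stand-in image
    masks, flows, styles = model.eval(img, diameter=45,
                                      flow_threshold=0.4, cellprob_threshold=0.0)
    # flows[0]: RGB flow rendering, flows[1]: raw dP, flows[2]: cellprob
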
337
+ def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
338
+ """
339
+ Resize cellprob array to specified dimensions for either 2D or 3D.
340
+
341
+ Parameters:
342
+ prob (numpy.ndarray): The cellprobs to resize, either in 2D or 3D. Returns the same ndim as provided.
343
+ to_y_size (int): The target size along the Y-axis.
344
+ to_x_size (int): The target size along the X-axis.
345
+ to_z_size (int, optional): The target size along the Z-axis. Required
346
+ for 3D cellprobs.
347
+
348
+ Returns:
349
+ numpy.ndarray: The resized cellprobs array with the same number of dimensions
350
+ as the input.
351
+
352
+ Raises:
353
+ ValueError: If the input cellprobs array does not have 3 or 4 dimensions.
354
+ """
355
+ prob_shape = prob.shape
356
+ prob = prob.squeeze()
357
+ squeeze_happened = prob.shape != prob_shape
358
+ prob_shape = np.array(prob_shape)
359
+
360
+ if prob.ndim == 2:
361
+ # 2D case:
362
+ prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
363
+ if squeeze_happened:
364
+ prob = np.expand_dims(prob, int(np.argwhere(prob_shape == 1))) # add back empty axis for compatibility
365
+ elif prob.ndim == 3:
366
+ # 3D case:
367
+ prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
368
+ prob = prob.transpose(1, 0, 2)
369
+ prob = transforms.resize_image(prob, Ly=to_z_size, Lx=to_x_size, no_channels=True)
370
+ prob = prob.transpose(1, 0, 2)
371
+ else:
372
+ raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 2 or 3, prob shape: {prob.shape}')
373
+
374
+ return prob
375
+
376
+
377
+ def _resize_gradients(self, grads: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
378
+ """
379
+ Resize gradient arrays to specified dimensions for either 2D or 3D gradients.
380
+
381
+ Parameters:
382
+ grads (np.ndarray): The gradients to resize, either in 2D or 3D. Returns the same ndim as provided.
383
+ to_y_size (int): The target size along the Y-axis.
384
+ to_x_size (int): The target size along the X-axis.
385
+ to_z_size (int, optional): The target size along the Z-axis. Required
386
+ for 3D gradients.
387
+
388
+ Returns:
389
+ numpy.ndarray: The resized gradient array with the same number of dimensions
390
+ as the input.
391
+
392
+ Raises:
393
+ ValueError: If the input gradient array does not have 3 or 4 dimensions.
394
+ """
395
+ grads_shape = grads.shape
396
+ grads = grads.squeeze()
397
+ squeeze_happened = grads.shape != grads_shape
398
+ grads_shape = np.array(grads_shape)
399
+
400
+ if grads.ndim == 3:
401
+ # 2D case, with XY flows in 2 channels:
402
+ grads = np.moveaxis(grads, 0, -1) # Put gradients last
403
+ grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
404
+ grads = np.moveaxis(grads, -1, 0) # Put gradients first
405
+
406
+ if squeeze_happened:
407
+ grads = np.expand_dims(grads, int(np.argwhere(grads_shape == 1))) # add back empty axis for compatibility
408
+ elif grads.ndim == 4:
409
+ # dP has gradients that can be treated as channels:
410
+ grads = grads.transpose(1, 2, 3, 0) # move gradients last:
411
+ grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
412
+ grads = grads.transpose(1, 0, 2, 3) # switch axes to resize again
413
+ grads = transforms.resize_image(grads, Ly=to_z_size, Lx=to_x_size, no_channels=False)
414
+ grads = grads.transpose(3, 1, 0, 2) # undo transposition
415
+ else:
416
+ raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 3 or 4, grads shape: {grads.shape}')
417
+
418
+ return grads
419
+
420
+
421
+ def _run_net(self, x, feat=None,
422
+ augment=False,
423
+ batch_size=8, tile_overlap=0.1,
424
+ bsize=224, anisotropy=1.0, do_3D=False):
425
+ """ run network on image x """
426
+ tic = time.time()
427
+ shape = x.shape
428
+ nimg = shape[0]
429
+
430
+
431
+ if do_3D:
432
+ Lz, Ly, Lx = shape[:-1]
433
+ if anisotropy is not None and anisotropy != 1.0:
434
+ models_logger.info(f"resizing 3D image with anisotropy={anisotropy}")
435
+ x = transforms.resize_image(x.transpose(1,0,2,3),
436
+ Ly=int(Lz*anisotropy),
437
+ Lx=int(Lx)).transpose(1,0,2,3)
438
+ yf, styles = run_3D(self.net, x,
439
+ batch_size=batch_size, augment=augment,
440
+ tile_overlap=tile_overlap,
441
+ bsize=bsize
442
+ )
443
+ cellprob = yf[..., -1]
444
+ dP = yf[..., :-1].transpose((3, 0, 1, 2))
445
+ else:
446
+ yf, styles = run_net(self.net, x, feat=feat, bsize=bsize, augment=augment,
447
+ batch_size=batch_size,
448
+ tile_overlap=tile_overlap,
449
+ )
450
+ cellprob = yf[..., -1]
451
+ dP = yf[..., -3:-1].transpose((3, 0, 1, 2))
452
+ if yf.shape[-1] > 3:
453
+ styles = yf[..., :-3]
454
+
455
+ styles = styles.squeeze()
456
+
457
+ net_time = time.time() - tic
458
+ if nimg > 1:
459
+ models_logger.info("network run in %2.2fs" % (net_time))
460
+
461
+ return dP, cellprob, styles
462
+
463
+ def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_threshold=0.0,
464
+ min_size=15, max_size_fraction=0.4, niter=None,
465
+ do_3D=False, stitch_threshold=0.0):
466
+ """ compute masks from flows and cell probability """
467
+ changed_device_from = None
468
+ if self.device.type == "mps" and do_3D:
469
+ models_logger.warning("MPS does not support 3D post-processing, switching to CPU")
470
+ self.device = torch.device("cpu")
471
+ changed_device_from = "mps"
472
+ Lz, Ly, Lx = shape[:3]
473
+ tic = time.time()
474
+ if do_3D:
475
+ masks = dynamics.resize_and_compute_masks(
476
+ dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
477
+ flow_threshold=flow_threshold, do_3D=do_3D,
478
+ min_size=min_size, max_size_fraction=max_size_fraction,
479
+ resize=shape[:3] if (np.array(dP.shape[-3:])!=np.array(shape[:3])).sum()
480
+ else None,
481
+ device=self.device)
482
+ else:
483
+ nimg = shape[0]
484
+ Ly0, Lx0 = cellprob[0].shape
485
+ resize = None if Ly0==Ly and Lx0==Lx else [Ly, Lx]
486
+ tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
487
+ iterator = trange(nimg, file=tqdm_out,
488
+ mininterval=30) if nimg > 1 else range(nimg)
489
+ for i in iterator:
490
+ # turn off min_size for 3D stitching
491
+ min_size0 = min_size if stitch_threshold == 0 or nimg == 1 else -1
492
+ outputs = dynamics.resize_and_compute_masks(
493
+ dP[:, i], cellprob[i],
494
+ niter=niter, cellprob_threshold=cellprob_threshold,
495
+ flow_threshold=flow_threshold, resize=resize,
496
+ min_size=min_size0, max_size_fraction=max_size_fraction,
497
+ device=self.device)
498
+ if i==0 and nimg > 1:
499
+ masks = np.zeros((nimg, shape[1], shape[2]), outputs.dtype)
500
+ if nimg > 1:
501
+ masks[i] = outputs
502
+ else:
503
+ masks = outputs
504
+
505
+ if stitch_threshold > 0 and nimg > 1:
506
+ models_logger.info(
507
+ f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks"
508
+ )
509
+ masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
510
+ masks = utils.fill_holes_and_remove_small_masks(
511
+ masks, min_size=min_size)
512
+ elif nimg > 1:
513
+ models_logger.warning(
514
+ "3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only"
515
+ )
516
+
517
+ flow_time = time.time() - tic
518
+ if shape[0] > 1:
519
+ models_logger.info("masks created in %2.2fs" % (flow_time))
520
+
521
+ if changed_device_from is not None:
522
+ models_logger.info("switching back to device %s" % self.device)
523
+ self.device = torch.device(changed_device_from)
524
+ return masks
models/seg_post_model/cellpose/plot.py ADDED
@@ -0,0 +1,281 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import os
5
+ import numpy as np
6
+ import cv2
7
+ from scipy.ndimage import gaussian_filter
8
+ from . import utils, io, transforms
9
+
10
+ try:
11
+ import matplotlib
12
+ MATPLOTLIB_ENABLED = True
13
+ except:
14
+ MATPLOTLIB_ENABLED = False
15
+
16
+ try:
17
+ from skimage import color
18
+ from skimage.segmentation import find_boundaries
19
+ SKIMAGE_ENABLED = True
20
+ except:
21
+ SKIMAGE_ENABLED = False
22
+
23
+
24
+ # modified to use sinebow color
25
+ def dx_to_circ(dP):
26
+ """Converts the optic flow representation to a circular color representation.
27
+
28
+ Args:
29
+ dP (ndarray): Flow field components [dy, dx].
30
+
31
+ Returns:
32
+ ndarray: The circular color representation of the optic flow.
33
+
34
+ """
35
+ mag = 255 * np.clip(transforms.normalize99(np.sqrt(np.sum(dP**2, axis=0))), 0, 1.)
36
+ angles = np.arctan2(dP[1], dP[0]) + np.pi
37
+ a = 2
38
+ mag /= a
39
+ rgb = np.zeros((*dP.shape[1:], 3), "uint8")
40
+ rgb[..., 0] = np.clip(mag * (np.cos(angles) + 1), 0, 255).astype("uint8")
41
+ rgb[..., 1] = np.clip(mag * (np.cos(angles + 2 * np.pi / 3) + 1), 0, 255).astype("uint8")
42
+ rgb[..., 2] = np.clip(mag * (np.cos(angles + 4 * np.pi / 3) + 1), 0, 255).astype("uint8")
43
+
44
+ return rgb
45
+
46
+
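
For intuition, a synthetic radial flow field rendered with `dx_to_circ` (illustrative arrays only): hue encodes flow angle and brightness encodes magnitude.

    import numpy as np

    yy, xx = np.mgrid[-50:50, -50:50].astype("float32")
    rgb = dx_to_circ(np.stack([yy, xx]))   # uint8 [100 x 100 x 3]
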
47
+ def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None):
48
+ """Plot segmentation results (like on website).
49
+
50
+ Can save each panel of figure with file_name option. Use channels option if
51
+ img input is not an RGB image with 3 channels.
52
+
53
+ Args:
54
+ fig (matplotlib.pyplot.figure): Figure in which to make plot.
55
+ img (ndarray): 2D or 3D array. Image input into cellpose.
56
+ maski (int, ndarray): For image k, masks[k] output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
57
+ flowi (int, ndarray): For image k, flows[k][0] output from Cellpose.eval (RGB of flows).
58
+ channels (list of int, optional): Channels used to run Cellpose, no need to use if image is RGB. Defaults to [0, 0].
59
+ file_name (str, optional): File name of image. If file_name is not None, figure panels are saved. Defaults to None.
61
+ """
62
+ if not MATPLOTLIB_ENABLED:
63
+ raise ImportError(
64
+ "matplotlib not installed, install with 'pip install matplotlib'")
65
+ ax = fig.add_subplot(1, 4, 1)
66
+ img0 = img.copy()
67
+
68
+ if img0.shape[0] < 4:
69
+ img0 = np.transpose(img0, (1, 2, 0))
70
+ if img0.shape[-1] < 3 or img0.ndim < 3:
71
+ img0 = image_to_rgb(img0, channels=channels)
72
+ else:
73
+ if img0.max() <= 50.0:
74
+ img0 = np.uint8(np.clip(img0, 0, 1) * 255)
75
+ ax.imshow(img0)
76
+ ax.set_title("original image")
77
+ ax.axis("off")
78
+
79
+ outlines = utils.masks_to_outlines(maski)
80
+
81
+ overlay = mask_overlay(img0, maski)
82
+
83
+ ax = fig.add_subplot(1, 4, 2)
84
+ outX, outY = np.nonzero(outlines)
85
+ imgout = img0.copy()
86
+ imgout[outX, outY] = np.array([255, 0, 0]) # pure red
87
+
88
+ ax.imshow(imgout)
89
+ ax.set_title("predicted outlines")
90
+ ax.axis("off")
91
+
92
+ ax = fig.add_subplot(1, 4, 3)
93
+ ax.imshow(overlay)
94
+ ax.set_title("predicted masks")
95
+ ax.axis("off")
96
+
97
+ ax = fig.add_subplot(1, 4, 4)
98
+ ax.imshow(flowi)
99
+ ax.set_title("predicted cell pose")
100
+ ax.axis("off")
101
+
102
+ if file_name is not None:
103
+ save_path = os.path.splitext(file_name)[0]
104
+ io.imsave(save_path + "_overlay.jpg", overlay)
105
+ io.imsave(save_path + "_outlines.jpg", imgout)
106
+ io.imsave(save_path + "_flows.jpg", flowi)
107
+
108
+
109
+ def mask_rgb(masks, colors=None):
110
+ """Masks in random RGB colors.
111
+
112
+ Args:
113
+ masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
114
+ colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range.
115
+
116
+ Returns:
117
+ RGB (uint8, 3D array): Array of masks overlaid on grayscale image.
118
+ """
119
+ if colors is not None:
120
+ if colors.max() > 1:
121
+ colors = np.float32(colors)
122
+ colors /= 255
123
+ colors = utils.rgb_to_hsv(colors)
124
+
125
+ HSV = np.zeros((masks.shape[0], masks.shape[1], 3), np.float32)
126
+ HSV[:, :, 2] = 1.0
127
+ for n in range(int(masks.max())):
128
+ ipix = (masks == n + 1).nonzero()
129
+ if colors is None:
130
+ HSV[ipix[0], ipix[1], 0] = np.random.rand()
131
+ else:
132
+ HSV[ipix[0], ipix[1], 0] = colors[n, 0]
133
+ HSV[ipix[0], ipix[1], 1] = np.random.rand() * 0.5 + 0.5
134
+ HSV[ipix[0], ipix[1], 2] = np.random.rand() * 0.5 + 0.5
135
+ RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8)
136
+ return RGB
137
+
138
+
139
+ def mask_overlay(img, masks, colors=None):
140
+ """Overlay masks on image (set image to grayscale).
141
+
142
+ Args:
143
+ img (int or float, 2D or 3D array): Image of size [Ly x Lx (x nchan)].
144
+ masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
145
+ colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range.
146
+
147
+ Returns:
148
+ RGB (uint8, 3D array): Array of masks overlaid on grayscale image.
149
+ """
150
+ if colors is not None:
151
+ if colors.max() > 1:
152
+ colors = np.float32(colors)
153
+ colors /= 255
154
+ colors = utils.rgb_to_hsv(colors)
155
+ if img.ndim > 2:
156
+ img = img.astype(np.float32).mean(axis=-1)
157
+ else:
158
+ img = img.astype(np.float32)
159
+
160
+ HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
161
+ HSV[:, :, 2] = np.clip((img / 255. if img.max() > 1 else img) * 1.5, 0, 1)
162
+ hues = np.linspace(0, 1, masks.max() + 1)[np.random.permutation(masks.max())]
163
+ for n in range(int(masks.max())):
164
+ ipix = (masks == n + 1).nonzero()
165
+ if colors is None:
166
+ HSV[ipix[0], ipix[1], 0] = hues[n]
167
+ else:
168
+ HSV[ipix[0], ipix[1], 0] = colors[n, 0]
169
+ HSV[ipix[0], ipix[1], 1] = 1.0
170
+ RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8)
171
+ return RGB
172
+
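
A small sketch of `mask_overlay` (assumed arrays): two square masks drawn in random hues over a grayscale image.

    import numpy as np

    img = (np.random.rand(64, 64) * 255).astype("uint8")
    masks = np.zeros((64, 64), int)
    masks[10:30, 10:30] = 1
    masks[35:55, 35:55] = 2
    rgb = mask_overlay(img, masks)   # uint8 [64 x 64 x 3]
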
173
+
174
+ def image_to_rgb(img0, channels=[0, 0]):
175
+ """Converts image from 2 x Ly x Lx or Ly x Lx x 2 to RGB Ly x Lx x 3.
176
+
177
+ Args:
178
+ img0 (ndarray): Input image of shape 2 x Ly x Lx or Ly x Lx x 2.
179
+
180
+ Returns:
181
+ ndarray: RGB image of shape Ly x Lx x 3.
182
+
183
+ """
184
+ img = img0.copy()
185
+ img = img.astype(np.float32)
186
+ if img.ndim < 3:
187
+ img = img[:, :, np.newaxis]
188
+ if img.shape[0] < 5:
189
+ img = np.transpose(img, (1, 2, 0))
190
+ if channels[0] == 0:
191
+ img = img.mean(axis=-1)[:, :, np.newaxis]
192
+ for i in range(img.shape[-1]):
193
+ if np.ptp(img[:, :, i]) > 0:
194
+ img[:, :, i] = np.clip(transforms.normalize99(img[:, :, i]), 0, 1)
195
+ img[:, :, i] = np.clip(img[:, :, i], 0, 1)
196
+ img *= 255
197
+ img = np.uint8(img)
198
+ RGB = np.zeros((img.shape[0], img.shape[1], 3), np.uint8)
199
+ if img.shape[-1] == 1:
200
+ RGB = np.tile(img, (1, 1, 3))
201
+ else:
202
+ RGB[:, :, channels[0] - 1] = img[:, :, 0]
203
+ if channels[1] > 0:
204
+ RGB[:, :, channels[1] - 1] = img[:, :, 1]
205
+ return RGB
206
+
207
+
208
+ def interesting_patch(mask, bsize=130):
209
+ """
210
+ Get patch of size bsize x bsize with most masks.
211
+
212
+ Args:
213
+ mask (ndarray): Input mask.
214
+ bsize (int): Size of the patch.
215
+
216
+ Returns:
217
+ tuple: Patch coordinates (y, x).
218
+
219
+ """
220
+ Ly, Lx = mask.shape
221
+ m = np.float32(mask > 0)
222
+ m = gaussian_filter(m, bsize / 2)
223
+ y, x = np.unravel_index(np.argmax(m), m.shape)
224
+ ycent = max(bsize // 2, min(y, Ly - bsize // 2))
225
+ xcent = max(bsize // 2, min(x, Lx - bsize // 2))
226
+ patch = [
227
+ np.arange(ycent - bsize // 2, ycent + bsize // 2, 1, int),
228
+ np.arange(xcent - bsize // 2, xcent + bsize // 2, 1, int)
229
+ ]
230
+ return patch
231
+
232
+
233
+ def disk(med, r, Ly, Lx):
234
+ """Returns the pixels of a disk with a given radius and center.
235
+
236
+ Args:
237
+ med (tuple): The center coordinates of the disk.
238
+ r (float): The radius of the disk.
239
+ Ly (int): The height of the image.
240
+ Lx (int): The width of the image.
241
+
242
+ Returns:
243
+ tuple: A tuple containing the y and x coordinates of the pixels within the disk.
244
+
245
+ """
246
+ yy, xx = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
247
+ indexing="ij")
248
+ inds = ((yy - med[0])**2 + (xx - med[1])**2)**0.5 <= r
249
+ y = yy[inds].flatten()
250
+ x = xx[inds].flatten()
251
+ return y, x
252
+
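
Illustrative use of `disk` (assumed frame size): collect the pixels of a radius-5 disk centred at (20, 30) and paint them into a boolean mask.

    import numpy as np

    y, x = disk((20, 30), r=5, Ly=64, Lx=64)
    m = np.zeros((64, 64), bool)
    m[y, x] = True
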
253
+
254
+ def outline_view(img0, maski, color=[1, 0, 0], mode="inner"):
255
+ """
256
+ Generates a red outline overlay onto the image.
257
+
258
+ Args:
259
+ img0 (numpy.ndarray): The input image.
260
+ maski (numpy.ndarray): The mask representing the region of interest.
261
+ color (list, optional): The color of the outline overlay. Defaults to [1, 0, 0] (red).
262
+ mode (str, optional): The mode for generating the outline. Defaults to "inner".
263
+
264
+ Returns:
265
+ numpy.ndarray: The image with the red outline overlay.
266
+
267
+ """
268
+ if img0.ndim == 2:
269
+ img0 = np.stack([img0] * 3, axis=-1)
270
+ elif img0.ndim != 3:
271
+ raise ValueError("img0 not right size (must have ndim 2 or 3)")
272
+
273
+ if SKIMAGE_ENABLED:
274
+ outlines = find_boundaries(maski, mode=mode)
275
+ else:
276
+ outlines = utils.masks_to_outlines(maski, mode=mode)
277
+ outY, outX = np.nonzero(outlines)
278
+ imgout = img0.copy()
279
+ imgout[outY, outX] = np.array(color)
280
+
281
+ return imgout
models/seg_post_model/cellpose/transforms.py ADDED
@@ -0,0 +1,1261 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import logging
5
+
6
+ import cv2
7
+ import numpy as np
8
+ import torch
9
+ from scipy.ndimage import gaussian_filter1d
10
+ from torch.fft import fft2, fftshift, ifft2
11
+
12
+ transforms_logger = logging.getLogger(__name__)
13
+
14
+
15
+ def _taper_mask(ly=224, lx=224, sig=7.5):
16
+ """
17
+ Generate a taper mask.
18
+
19
+ Args:
20
+ ly (int): The height of the mask. Default is 224.
21
+ lx (int): The width of the mask. Default is 224.
22
+ sig (float): The sigma value for the tapering function. Default is 7.5.
23
+
24
+ Returns:
25
+ numpy.ndarray: The taper mask.
26
+
27
+ """
28
+ bsize = max(224, max(ly, lx))
29
+ xm = np.arange(bsize)
30
+ xm = np.abs(xm - xm.mean())
31
+ mask = 1 / (1 + np.exp((xm - (bsize / 2 - 20)) / sig))
32
+ mask = mask * mask[:, np.newaxis]
33
+ mask = mask[bsize // 2 - ly // 2:bsize // 2 + ly // 2 + ly % 2,
34
+ bsize // 2 - lx // 2:bsize // 2 + lx // 2 + lx % 2]
35
+ return mask
36
+
37
+
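
The taper is ~1 in the tile centre and decays sigmoidally toward the borders, which is what lets overlapping tiles blend smoothly when averaged. A quick check (illustrative call):

    m = _taper_mask(ly=224, lx=224, sig=7.5)
    print(m.shape, m[112, 112], m[0, 0])   # (224, 224), ~1.0 centre, ~0 corner
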
38
+ def unaugment_tiles(y):
39
+ """Reverse test-time augmentations for averaging (includes flipping of flowsY and flowsX).
40
+
41
+ Args:
42
+ y (float32): Array of shape (ntiles_y, ntiles_x, chan, Ly, Lx) where chan = (flowsY, flowsX, cell prob).
43
+
44
+ Returns:
45
+ float32: Array of shape (ntiles_y, ntiles_x, chan, Ly, Lx).
46
+
47
+ """
48
+ for j in range(y.shape[0]):
49
+ for i in range(y.shape[1]):
50
+ if j % 2 == 0 and i % 2 == 1:
51
+ y[j, i] = y[j, i, :, ::-1, :]
52
+ y[j, i, 0] *= -1
53
+ elif j % 2 == 1 and i % 2 == 0:
54
+ y[j, i] = y[j, i, :, :, ::-1]
55
+ y[j, i, 1] *= -1
56
+ elif j % 2 == 1 and i % 2 == 1:
57
+ y[j, i] = y[j, i, :, ::-1, ::-1]
58
+ y[j, i, 0] *= -1
59
+ y[j, i, 1] *= -1
60
+ return y
61
+
62
+
63
+ def average_tiles(y, ysub, xsub, Ly, Lx):
64
+ """
65
+ Average the results of the network over tiles.
66
+
67
+ Args:
68
+ y (float): Output of cellpose network for each tile. Shape: [ntiles x nclasses x bsize x bsize]
69
+ ysub (list): List of arrays with start and end of tiles in Y of length ntiles
70
+ xsub (list): List of arrays with start and end of tiles in X of length ntiles
71
+ Ly (int): Size of pre-tiled image in Y (may be larger than original image if image size is less than bsize)
72
+ Lx (int): Size of pre-tiled image in X (may be larger than original image if image size is less than bsize)
73
+
74
+ Returns:
75
+ yf (float32): Network output averaged over tiles. Shape: [nclasses x Ly x Lx]
76
+ """
77
+ Navg = np.zeros((Ly, Lx))
78
+ yf = np.zeros((y.shape[1], Ly, Lx), np.float32)
79
+ # taper edges of tiles
80
+ mask = _taper_mask(ly=y.shape[-2], lx=y.shape[-1])
81
+ for j in range(len(ysub)):
82
+ yf[:, ysub[j][0]:ysub[j][1], xsub[j][0]:xsub[j][1]] += y[j] * mask
83
+ Navg[ysub[j][0]:ysub[j][1], xsub[j][0]:xsub[j][1]] += mask
84
+ yf /= Navg
85
+ return yf
86
+
87
+
88
+ def make_tiles(imgi, bsize=224, augment=False, tile_overlap=0.1):
89
+ """Make tiles of image to run at test-time.
90
+
91
+ Args:
92
+ imgi (np.ndarray): Array of shape (nchan, Ly, Lx) representing the input image.
93
+ bsize (int, optional): Size of tiles. Defaults to 224.
94
+ augment (bool, optional): Whether to flip tiles and set tile_overlap=2. Defaults to False.
95
+ tile_overlap (float, optional): Fraction of overlap of tiles. Defaults to 0.1.
96
+
97
+ Returns:
98
+ A tuple containing (IMG, ysub, xsub, Ly, Lx):
99
+ IMG (np.ndarray): Array of shape (ntiles, nchan, bsize, bsize) representing the tiles.
100
+ ysub (list): List of arrays with start and end of tiles in Y of length ntiles.
101
+ xsub (list): List of arrays with start and end of tiles in X of length ntiles.
102
+ Ly (int): Height of the input image.
103
+ Lx (int): Width of the input image.
104
+ """
105
+ nchan, Ly, Lx = imgi.shape
106
+ if augment:
107
+ bsize = np.int32(bsize)
108
+ # pad if image smaller than bsize
109
+ if Ly < bsize:
110
+ imgi = np.concatenate((imgi, np.zeros((nchan, bsize - Ly, Lx))), axis=1)
111
+ Ly = bsize
112
+ if Lx < bsize:
113
+ imgi = np.concatenate((imgi, np.zeros((nchan, Ly, bsize - Lx))), axis=2)
114
+ Ly, Lx = imgi.shape[-2:]
115
+
116
+ # tiles overlap by half of tile size
117
+ ny = max(2, int(np.ceil(2. * Ly / bsize)))
118
+ nx = max(2, int(np.ceil(2. * Lx / bsize)))
119
+ ystart = np.linspace(0, Ly - bsize, ny).astype(int)
120
+ xstart = np.linspace(0, Lx - bsize, nx).astype(int)
121
+
122
+ ysub = []
123
+ xsub = []
124
+
125
+ # flip tiles so that overlapping segments are processed in rotation
126
+ IMG = np.zeros((len(ystart), len(xstart), nchan, bsize, bsize), np.float32)
127
+ for j in range(len(ystart)):
128
+ for i in range(len(xstart)):
129
+ ysub.append([ystart[j], ystart[j] + bsize])
130
+ xsub.append([xstart[i], xstart[i] + bsize])
131
+ IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]]
132
+ # flip tiles to allow for augmentation of overlapping segments
133
+ if j % 2 == 0 and i % 2 == 1:
134
+ IMG[j, i] = IMG[j, i, :, ::-1, :]
135
+ elif j % 2 == 1 and i % 2 == 0:
136
+ IMG[j, i] = IMG[j, i, :, :, ::-1]
137
+ elif j % 2 == 1 and i % 2 == 1:
138
+ IMG[j, i] = IMG[j, i, :, ::-1, ::-1]
139
+ else:
140
+ tile_overlap = min(0.5, max(0.05, tile_overlap))
141
+ bsizeY, bsizeX = min(bsize, Ly), min(bsize, Lx)
142
+ bsizeY = np.int32(bsizeY)
143
+ bsizeX = np.int32(bsizeX)
144
+ # tiles overlap by 10% tile size
145
+ ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
146
+ nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))
147
+ ystart = np.linspace(0, Ly - bsizeY, ny).astype(int)
148
+ xstart = np.linspace(0, Lx - bsizeX, nx).astype(int)
149
+
150
+ ysub = []
151
+ xsub = []
152
+ IMG = np.zeros((len(ystart), len(xstart), nchan, bsizeY, bsizeX), np.float32)
153
+ for j in range(len(ystart)):
154
+ for i in range(len(xstart)):
155
+ ysub.append([ystart[j], ystart[j] + bsizeY])
156
+ xsub.append([xstart[i], xstart[i] + bsizeX])
157
+ IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]]
158
+
159
+ return IMG, ysub, xsub, Ly, Lx
160
+
161
+
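
The tiling round-trip, end to end (toy image, assumed shapes): tile the image, pretend the network is the identity, and average back with the taper weights; the reconstruction should closely match the input.

    import numpy as np

    img = np.random.rand(1, 300, 400).astype("float32")    # nchan x Ly x Lx
    IMG, ysub, xsub, Ly, Lx = make_tiles(img, bsize=224, tile_overlap=0.1)
    y = IMG.reshape(-1, *IMG.shape[2:])                    # ntiles x nchan x b x b
    recon = average_tiles(y, ysub, xsub, Ly, Lx)           # ~= img
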
162
+ def normalize99(Y, lower=1, upper=99, copy=True, downsample=False):
163
+ """
164
+ Normalize the image so that 0.0 corresponds to the 1st percentile and 1.0 corresponds to the 99th percentile.
165
+
166
+ Args:
167
+ Y (ndarray): The input image (for downsample, use [Ly x Lx] or [Lz x Ly x Lx]).
168
+ lower (int, optional): The lower percentile. Defaults to 1.
169
+ upper (int, optional): The upper percentile. Defaults to 99.
170
+ copy (bool, optional): Whether to create a copy of the input image. Defaults to True.
171
+ downsample (bool, optional): Whether to downsample image to compute percentiles. Defaults to False.
172
+
173
+ Returns:
174
+ ndarray: The normalized image.
175
+ """
176
+ X = Y.copy() if copy else Y
177
+ X = X.astype("float32") if X.dtype!="float64" and X.dtype!="float32" else X
178
+ if downsample and X.size > 224**3:
179
+ nskip = [max(1, X.shape[i] // 224) for i in range(X.ndim)]
180
+ nskip[0] = max(1, X.shape[0] // 50) if X.ndim == 3 else nskip[0]
181
+ slc = tuple([slice(0, X.shape[i], nskip[i]) for i in range(X.ndim)])
182
+ x01 = np.percentile(X[slc], lower)
183
+ x99 = np.percentile(X[slc], upper)
184
+ else:
185
+ x01 = np.percentile(X, lower)
186
+ x99 = np.percentile(X, upper)
187
+ if x99 - x01 > 1e-3:
188
+ X -= x01
189
+ X /= (x99 - x01)
190
+ else:
191
+ X[:] = 0
192
+ return X
193
+
194
+
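
After `normalize99`, the lower percentile maps to ~0 and the upper to ~1; values outside that range fall below 0 or above 1 rather than being clipped. A quick check (illustrative arrays):

    import numpy as np

    x = (np.random.randn(256, 256) * 10 + 100).astype("float32")
    xn = normalize99(x)
    print(np.percentile(xn, 1), np.percentile(xn, 99))   # ~0.0, ~1.0
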
195
+ def normalize99_tile(img, blocksize=100, lower=1., upper=99., tile_overlap=0.1,
196
+ norm3D=False, smooth3D=1, is3D=False):
197
+ """Compute normalization like normalize99 function but in tiles.
198
+
199
+ Args:
200
+ img (numpy.ndarray): Array of shape (Lz x) Ly x Lx (x nchan) containing the image.
201
+ blocksize (float, optional): Size of tiles. Defaults to 100.
202
+ lower (float, optional): Lower percentile for normalization. Defaults to 1.0.
203
+ upper (float, optional): Upper percentile for normalization. Defaults to 99.0.
204
+ tile_overlap (float, optional): Fraction of overlap of tiles. Defaults to 0.1.
205
+ norm3D (bool, optional): Use same tiled normalization for each z-plane. Defaults to False.
206
+ smooth3D (int, optional): Smoothing factor for 3D normalization. Defaults to 1.
207
+ is3D (bool, optional): Set to True if image is a 3D stack. Defaults to False.
208
+
209
+ Returns:
210
+ numpy.ndarray: Normalized image array of shape (Lz x) Ly x Lx (x nchan).
211
+ """
212
+ is1c = True if img.ndim == 2 or (is3D and img.ndim == 3) else False
213
+ is3D = True if img.ndim > 3 or (is3D and img.ndim == 3) else False
214
+ img = img[..., np.newaxis] if is1c else img
215
+ img = img[np.newaxis, ...] if img.ndim == 3 else img
216
+ Lz, Ly, Lx, nchan = img.shape
217
+
218
+ tile_overlap = min(0.5, max(0.05, tile_overlap))
219
+ blocksizeY, blocksizeX = min(blocksize, Ly), min(blocksize, Lx)
220
+ blocksizeY = np.int32(blocksizeY)
221
+ blocksizeX = np.int32(blocksizeX)
222
+ # tiles overlap by 10% tile size
223
+ ny = 1 if Ly <= blocksize else int(np.ceil(
224
+ (1. + 2 * tile_overlap) * Ly / blocksize))
225
+ nx = 1 if Lx <= blocksize else int(np.ceil(
226
+ (1. + 2 * tile_overlap) * Lx / blocksize))
227
+ ystart = np.linspace(0, Ly - blocksizeY, ny).astype(int)
228
+ xstart = np.linspace(0, Lx - blocksizeX, nx).astype(int)
229
+ ysub = []
230
+ xsub = []
231
+ for j in range(len(ystart)):
232
+ for i in range(len(xstart)):
233
+ ysub.append([ystart[j], ystart[j] + blocksizeY])
234
+ xsub.append([xstart[i], xstart[i] + blocksizeX])
235
+
236
+ x01_tiles_z = []
237
+ x99_tiles_z = []
238
+ for z in range(Lz):
239
+ IMG = np.zeros((len(ystart), len(xstart), blocksizeY, blocksizeX, nchan),
240
+ "float32")
241
+ k = 0
242
+ for j in range(len(ystart)):
243
+ for i in range(len(xstart)):
244
+ IMG[j, i] = img[z, ysub[k][0]:ysub[k][1], xsub[k][0]:xsub[k][1], :]
245
+ k += 1
246
+ x01_tiles = np.percentile(IMG, lower, axis=(-3, -2))
247
+ x99_tiles = np.percentile(IMG, upper, axis=(-3, -2))
248
+
249
+ # fill areas with small differences with neighboring squares
250
+ to_fill = np.zeros(x01_tiles.shape[:2], "bool")
251
+ for c in range(nchan):
252
+ to_fill = x99_tiles[:, :, c] - x01_tiles[:, :, c] < 1e-3
253
+ if to_fill.sum() > 0 and to_fill.sum() < x99_tiles[:, :, c].size:
254
+ fill_vals = np.nonzero(to_fill)
255
+ fill_neigh = np.nonzero(~to_fill)
256
+ nearest_neigh = (
257
+ (fill_vals[0] - fill_neigh[0][:, np.newaxis])**2 +
258
+ (fill_vals[1] - fill_neigh[1][:, np.newaxis])**2).argmin(axis=0)
259
+ x01_tiles[fill_vals[0], fill_vals[1],
260
+ c] = x01_tiles[fill_neigh[0][nearest_neigh],
261
+ fill_neigh[1][nearest_neigh], c]
262
+ x99_tiles[fill_vals[0], fill_vals[1],
263
+ c] = x99_tiles[fill_neigh[0][nearest_neigh],
264
+ fill_neigh[1][nearest_neigh], c]
265
+ elif to_fill.sum() > 0 and to_fill.sum() == x99_tiles[:, :, c].size:
266
+ x01_tiles[:, :, c] = 0
267
+ x99_tiles[:, :, c] = 1
268
+ x01_tiles_z.append(x01_tiles)
269
+ x99_tiles_z.append(x99_tiles)
270
+
271
+ x01_tiles_z = np.array(x01_tiles_z)
272
+ x99_tiles_z = np.array(x99_tiles_z)
273
+ # do not smooth over z-axis if not normalizing separately per plane
274
+ for a in range(2):
275
+ x01_tiles_z = gaussian_filter1d(x01_tiles_z, 1, axis=a)
276
+ x99_tiles_z = gaussian_filter1d(x99_tiles_z, 1, axis=a)
277
+ if norm3D:
278
+ smooth3D = 1 if smooth3D == 0 else smooth3D
279
+ x01_tiles_z = gaussian_filter1d(x01_tiles_z, smooth3D, axis=a)
280
+ x99_tiles_z = gaussian_filter1d(x99_tiles_z, smooth3D, axis=a)
281
+
282
+ if not norm3D and Lz > 1:
283
+ x01 = np.zeros((len(x01_tiles_z), Ly, Lx, nchan), "float32")
284
+ x99 = np.zeros((len(x01_tiles_z), Ly, Lx, nchan), "float32")
285
+ for z in range(Lz):
286
+ x01_rsz = cv2.resize(x01_tiles_z[z], (Lx, Ly),
287
+ interpolation=cv2.INTER_LINEAR)
288
+ x01[z] = x01_rsz[..., np.newaxis] if nchan == 1 else x01_rsz
289
+ x99_rsz = cv2.resize(x99_tiles_z[z], (Lx, Ly),
290
+ interpolation=cv2.INTER_LINEAR)
291
+ x99[z] = x99_rsz[..., np.newaxis] if nchan == 1 else x99_rsz
292
+ if (x99 - x01).min() < 1e-3:
293
+ raise ZeroDivisionError(
294
+ "cannot use norm3D=False with tile_norm, sample is too sparse; set norm3D=True or tile_norm=0"
295
+ )
296
+ else:
297
+ x01 = cv2.resize(x01_tiles_z.mean(axis=0), (Lx, Ly),
298
+ interpolation=cv2.INTER_LINEAR)
299
+ x99 = cv2.resize(x99_tiles_z.mean(axis=0), (Lx, Ly),
300
+ interpolation=cv2.INTER_LINEAR)
301
+ if x01.ndim < 3:
302
+ x01 = x01[..., np.newaxis]
303
+ x99 = x99[..., np.newaxis]
304
+
305
+ if is1c:
306
+ img, x01, x99 = img.squeeze(), x01.squeeze(), x99.squeeze()
307
+ elif not is3D:
308
+ img, x01, x99 = img[0], x01[0], x99[0]
309
+
310
+ # normalize
311
+ img -= x01
312
+ img /= (x99 - x01)
313
+
314
+ return img
315
+
316
+
317
+ def gaussian_kernel(sigma, Ly, Lx, device=torch.device("cpu")):
318
+ """
319
+ Generates a 2D Gaussian kernel.
320
+
321
+ Args:
322
+ sigma (float): Standard deviation of the Gaussian distribution.
323
+ Ly (int): Number of pixels in the y-axis.
324
+ Lx (int): Number of pixels in the x-axis.
325
+ device (torch.device, optional): Device to store the kernel tensor. Defaults to torch.device("cpu").
326
+
327
+ Returns:
328
+ torch.Tensor: 2D Gaussian kernel tensor.
329
+
330
+ """
331
+ y = torch.linspace(-Ly / 2, Ly / 2 + 1, Ly, device=device)
332
+ x = torch.linspace(-Lx / 2, Lx / 2 + 1, Lx, device=device)
333
+ y, x = torch.meshgrid(y, x, indexing="ij")
334
+ kernel = torch.exp(-(y**2 + x**2) / (2 * sigma**2))
335
+ kernel /= kernel.sum()
336
+ return kernel
337
+
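# Illustrative sketch (editor's example): the kernel is normalized to sum to 1
# and is used below as an FFT filter in smooth_sharpen_img.
import torch
k = gaussian_kernel(sigma=3., Ly=16, Lx=16)
assert torch.isclose(k.sum(), torch.tensor(1.))
print(k.shape)  # torch.Size([16, 16])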
338
+
339
+ def smooth_sharpen_img(img, smooth_radius=6, sharpen_radius=12,
340
+ device=torch.device("cpu"), is3D=False):
341
+ """Sharpen blurry images with surround subtraction and/or smooth noisy images.
342
+
343
+ Args:
344
+ img (float32): Array that's (Lz x) Ly x Lx (x nchan).
345
+ smooth_radius (float, optional): Size of gaussian smoothing filter, recommended to be 1/10-1/4 of cell diameter
346
+ (if also sharpening, should be 2-3x smaller than sharpen_radius). Defaults to 6.
347
+ sharpen_radius (float, optional): Size of gaussian surround filter, recommended to be 1/8-1/2 of cell diameter
348
+ (if also smoothing, should be 2-3x larger than smooth_radius). Defaults to 12.
349
+ device (torch.device, optional): Device on which to perform sharpening.
350
+ Will be faster on GPU but need to ensure GPU has RAM for image. Defaults to torch.device("cpu").
351
+ is3D (bool, optional): If image is 3D stack (only necessary to set if img.ndim==3). Defaults to False.
352
+
353
+ Returns:
354
+ img_sharpen (float32): Array that's (Lz x) Ly x Lx (x nchan).
355
+ """
356
+ img_sharpen = torch.from_numpy(img.astype("float32")).to(device)
357
+ shape = img_sharpen.shape
358
+
359
+ is1c = True if img_sharpen.ndim == 2 or (is3D and img_sharpen.ndim == 3) else False
360
+ is3D = True if img_sharpen.ndim > 3 or (is3D and img_sharpen.ndim == 3) else False
361
+ img_sharpen = img_sharpen.unsqueeze(-1) if is1c else img_sharpen
362
+ img_sharpen = img_sharpen.unsqueeze(0) if img_sharpen.ndim == 3 else img_sharpen
363
+ Lz, Ly, Lx, nchan = img_sharpen.shape
364
+
365
+ if smooth_radius > 0:
366
+ kernel = gaussian_kernel(smooth_radius, Ly, Lx, device=device)
367
+ if sharpen_radius > 0:
368
+ kernel += -1 * gaussian_kernel(sharpen_radius, Ly, Lx, device=device)
369
+ elif sharpen_radius > 0:
370
+ kernel = -1 * gaussian_kernel(sharpen_radius, Ly, Lx, device=device)
371
+ kernel[Ly // 2, Lx // 2] = 1
372
+
373
+ fhp = fft2(kernel)
374
+ for z in range(Lz):
375
+ for c in range(nchan):
376
+ img_filt = torch.real(ifft2(
377
+ fft2(img_sharpen[z, :, :, c]) * torch.conj(fhp)))
378
+ img_filt = fftshift(img_filt)
379
+ img_sharpen[z, :, :, c] = img_filt
380
+
381
+ img_sharpen = img_sharpen.reshape(shape)
382
+ return img_sharpen.cpu().numpy()
383
+
384
+
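# Illustrative sketch: FFT-based smoothing of a noisy 2D image on CPU; the
# output keeps the input shape. Parameter values here are arbitrary examples.
import numpy as np
img = np.random.rand(128, 128).astype("float32")
img_smooth = smooth_sharpen_img(img, smooth_radius=4, sharpen_radius=0)
print(img_smooth.shape)  # (128, 128), same shape as the input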
385
+ def move_axis(img, m_axis=-1, first=True):
386
+ """ move axis m_axis to first or last position """
387
+ if m_axis == -1:
388
+ m_axis = img.ndim - 1
389
+ m_axis = min(img.ndim - 1, m_axis)
390
+ axes = np.arange(0, img.ndim)
391
+ if first:
392
+ axes[1:m_axis + 1] = axes[:m_axis]
393
+ axes[0] = m_axis
394
+ else:
395
+ axes[m_axis:-1] = axes[m_axis + 1:]
396
+ axes[-1] = m_axis
397
+ img = img.transpose(tuple(axes))
398
+ return img
399
+
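# Illustrative sketch: moving a channel axis to the front or the back.
import numpy as np
img = np.zeros((3, 64, 64))              # channels-first
img_last = move_axis(img, m_axis=0, first=False)
print(img_last.shape)                    # (64, 64, 3)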
400
+
401
+ def move_min_dim(img, force=False):
402
+ """Move the minimum dimension last as channels if it is less than 10 or force is True.
403
+
404
+ Args:
405
+ img (ndarray): The input image.
406
+ force (bool, optional): If True, the minimum dimension will always be moved.
407
+ Defaults to False.
408
+
409
+ Returns:
410
+ ndarray: The image with the minimum dimension moved to the last axis as channels.
411
+ """
412
+ if len(img.shape) > 2:
413
+ min_dim = min(img.shape)
414
+ if min_dim < 10 or force:
415
+ if img.shape[-1] == min_dim:
416
+ channel_axis = -1
417
+ else:
418
+ channel_axis = (img.shape).index(min_dim)
419
+ img = move_axis(img, m_axis=channel_axis, first=False)
420
+ return img
421
+
422
+
423
+ def update_axis(m_axis, to_squeeze, ndim):
424
+ """
425
+ Squeeze the axis value based on the given parameters.
426
+
427
+ Args:
428
+ m_axis (int): The current axis value.
429
+ to_squeeze (numpy.ndarray): An array of indices to squeeze.
430
+ ndim (int): The number of dimensions.
431
+
432
+ Returns:
433
+ m_axis (int or None): The updated axis value.
434
+ """
435
+ if m_axis == -1:
436
+ m_axis = ndim - 1
437
+ if (to_squeeze == m_axis).sum() == 1:
438
+ m_axis = None
439
+ else:
440
+ inds = np.ones(ndim, bool)
441
+ inds[to_squeeze] = False
442
+ m_axis = np.nonzero(np.arange(0, ndim)[inds] == m_axis)[0]
443
+ if len(m_axis) > 0:
444
+ m_axis = m_axis[0]
445
+ else:
446
+ m_axis = None
447
+ return m_axis
448
+
449
+
450
+ def _convert_image_3d(x, channel_axis=None, z_axis=None):
451
+ """
452
+ Convert a 3D or 4D image array to have dimensions ordered as (Z, X, Y, C).
453
+
454
+ Arrays of ndim=3 are assumed to be grayscale and must be specified with z_axis.
455
+ Arrays of ndim=4 must have both `channel_axis` and `z_axis` specified.
456
+
457
+ Args:
458
+ x (numpy.ndarray): Input image array. Must be either 3D (assumed to be grayscale 3D) or 4D.
459
+ channel_axis (int): The axis index corresponding to the channel dimension in the input array. \
460
+ Must be specified for 4D images.
461
+ z_axis (int): The axis index corresponding to the depth (Z) dimension in the input array. \
462
+ Must be specified for both 3D and 4D images.
463
+
464
+ Returns:
465
+ numpy.ndarray: A 4D image array with dimensions ordered as (Z, X, Y, C), where C is the channel
466
+ dimension. If the input has fewer than 3 channels, the output will be padded with zeros to \
467
+ have 3 channels. If the input has more than 3 channels, only the first 3 channels will be retained.
468
+
469
+ Raises:
470
+ ValueError: If `z_axis` is not specified for 3D images. If either `channel_axis` or `z_axis` \
471
+ is not specified for 4D images. If the input image does not have 3 or 4 dimensions.
472
+
473
+ Notes:
474
+ - For 3D images (ndim=3), the function assumes the input is grayscale and adds a singleton channel dimension.
475
+ - The function reorders the dimensions of the input array to ensure the output has the desired (Z, X, Y, C) order.
476
+ - If the number of channels is not equal to 3, the function either truncates or pads the \
477
+ channels to ensure the output has exactly 3 channels.
478
+ """
479
+
480
+ if x.ndim < 3:
481
+ raise ValueError(f"Input image must have at least 3 dimensions, input shape: {x.shape}, ndim={x.ndim}")
482
+
483
+ if z_axis is not None and z_axis < 0:
484
+ z_axis += x.ndim
485
+
486
+ # if image is ndim==3, assume it is greyscale 3D and use provided z_axis
487
+ if x.ndim == 3 and z_axis is not None:
488
+ # add in channel axis
489
+ x = x[..., np.newaxis]
490
+ channel_axis = 3
491
+ elif x.ndim == 3 and z_axis is None:
492
+ raise ValueError("z_axis must be specified when segmenting 3D images of ndim=3")
493
+
494
+
495
+ if channel_axis is None or z_axis is None:
496
+ raise ValueError("For 4D images, both `channel_axis` and `z_axis` must be explicitly specified. Please provide values for both parameters.")
497
+ if channel_axis is not None and channel_axis < 0:
498
+ channel_axis += x.ndim
499
+ if channel_axis is None or channel_axis >= x.ndim:
500
+ raise IndexError(f"channel_axis {channel_axis} is out of bounds for input array with {x.ndim} dimensions")
501
+ assert x.ndim == 4, f"input image must have ndim == 4, ndim={x.ndim}"
502
+
503
+ x_dim_shapes = list(x.shape)
504
+ num_z_layers = x_dim_shapes[z_axis]
505
+ num_channels = x_dim_shapes[channel_axis]
506
+ x_xy_axes = [i for i in range(x.ndim)]
507
+
508
+ # need to remove the z and channels from the shapes:
509
+ # delete the one with the bigger index first
510
+ if z_axis > channel_axis:
511
+ del x_dim_shapes[z_axis]
512
+ del x_dim_shapes[channel_axis]
513
+
514
+ del x_xy_axes[z_axis]
515
+ del x_xy_axes[channel_axis]
516
+
517
+ else:
518
+ del x_dim_shapes[channel_axis]
519
+ del x_dim_shapes[z_axis]
520
+
521
+ del x_xy_axes[channel_axis]
522
+ del x_xy_axes[z_axis]
523
+
524
+ x = x.transpose((z_axis, x_xy_axes[0], x_xy_axes[1], channel_axis))
525
+
526
+ # Handle cases with not 3 channels:
527
+ if num_channels != 3:
528
+ x_chans_to_copy = min(3, num_channels)
529
+
530
+ if num_channels > 3:
531
+ transforms_logger.warning("more than 3 channels provided, only segmenting on first 3 channels")
532
+ x = x[..., :x_chans_to_copy]
533
+ else:
534
+ # less than 3 channels: pad up to
535
+ pad_width = [(0, 0), (0, 0), (0, 0), (0, 3 - x_chans_to_copy)]
536
+ x = np.pad(x, pad_width, mode='constant', constant_values=0)
537
+
538
+ return x
539
+
540
+
541
+ def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
542
+ """Converts the image to have the z-axis first, channels last. Image will be converted to 3 channels if it is not already.
543
+ If more than 3 channels are provided, only the first 3 channels will be used.
544
+
545
+ Accepts:
546
+ - 2D images with no channel dimension: `z_axis` and `channel_axis` must be `None`
547
+ - 2D images with channel dimension: `channel_axis` is guessed (first or last axis) if not specified. `z_axis` must be `None`
548
+ - 3D images with or without channels: `do_3D` must be True, and `z_axis` (plus `channel_axis` for 4D inputs) must be specified
549
+
550
+ Args:
551
+ x (numpy.ndarray or torch.Tensor): The input image.
552
+ channel_axis (int or None): The axis of the channels in the input image. If None, the axis is determined automatically.
553
+ z_axis (int or None): The axis of the z-dimension in the input image. If None, the axis is determined automatically.
554
+ do_3D (bool): Whether to process the image in 3D mode. Defaults to False.
555
+
556
+ Returns:
557
+ numpy.ndarray: The converted image.
558
+
559
+ Raises:
560
+ ValueError: If the input image is 2D and do_3D is True.
561
+ ValueError: If the input image is 4D and do_3D is False.
562
+ """
563
+
564
+ # check if image is a torch array instead of numpy array, convert to numpy
565
+ ndim = x.ndim
566
+ if torch.is_tensor(x):
567
+ transforms_logger.warning("torch array used as input, converting to numpy")
568
+ x = x.cpu().numpy()
569
+
570
+ # z_axis only applies to 3D processing
571
+ if z_axis is not None and not do_3D:
572
+ raise ValueError("z_axis was specified but do_3D is False. Set z_axis=None to process 2D images of ndim=2 or 3.")
573
+
574
+ # 4D inputs require do_3D=True
575
+ if ndim == 4 and not do_3D:
576
+ raise ValueError("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images. ndims=4")
577
+
578
+ # dispatch 3D/4D images to the 3D converter
579
+ if do_3D:
580
+ return _convert_image_3d(x, channel_axis=channel_axis, z_axis=z_axis)
581
+
582
+ ######################## 2D reshaping ########################
583
+ # if user specifies channel axis, return early
584
+ if channel_axis is not None:
585
+ if ndim == 2:
586
+ raise ValueError("2D image provided, but channel_axis is not None. Set channel_axis=None to process 2D images of ndim=2.")
587
+
588
+ # Put channel axis last:
589
+ # Find the indices of the dims that need to be put in dim 0 and 1
590
+ n_channels = x.shape[channel_axis]
591
+ x_shape_dims = list(x.shape)
592
+ del x_shape_dims[channel_axis]
593
+ dimension_indices = [i for i in range(x.ndim)]
594
+ del dimension_indices[channel_axis]
595
+
596
+ x = x.transpose((dimension_indices[0], dimension_indices[1], channel_axis))
597
+
598
+ if n_channels != 3:
599
+ x_chans_to_copy = min(3, n_channels)
600
+
601
+ if n_channels > 3:
602
+ transforms_logger.warning("more than 3 channels provided, only segmenting on first 3 channels")
603
+ x = x[..., :x_chans_to_copy]
604
+ else:
605
+ x_out = np.zeros((x_shape_dims[0], x_shape_dims[1], 3), dtype=x.dtype)
606
+ x_out[..., :x_chans_to_copy] = x[...]
607
+ x = x_out
608
+ del x_out
609
+
610
+ return x
611
+
612
+ # do image padding and channel conversion
613
+ if ndim == 2:
614
+ # grayscale image, make 3 channels
615
+ x_out = np.zeros((x.shape[0], x.shape[1], 3), dtype=x.dtype)
616
+ x_out[..., 0] = x
617
+ x = x_out
618
+ del x_out
619
+ elif ndim == 3:
620
+ # assume 2d with channels
621
+ # find dim with smaller size between first and last dims
622
+ move_channel_axis = x.shape[0] < x.shape[2]
623
+ if move_channel_axis:
624
+ x = x.transpose((1, 2, 0))
625
+
626
+ # zero padding up to 3 channels:
627
+ num_channels = x.shape[-1]
628
+ if num_channels > 3:
629
+ transforms_logger.warning("Found more than 3 channels, only using first 3")
630
+ num_channels = 3
631
+ x_out = np.zeros((x.shape[0], x.shape[1], 3), dtype=x.dtype)
632
+ x_out[..., :num_channels] = x[..., :num_channels]
633
+ x = x_out
634
+ del x_out
635
+ else:
636
+ # something is wrong: yell
637
+ expected_shapes = "2D (H, W), 3D (H, W, C), or 4D (Z, H, W, C)"
638
+ transforms_logger.critical(f"ERROR: Unexpected image shape: {str(x.shape)}. Expected shapes: {expected_shapes}")
639
+ raise ValueError(f"ERROR: Unexpected image shape: {str(x.shape)}. Expected shapes: {expected_shapes}")
640
+
641
+ return x
642
+
643
+
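# Illustrative sketch: 2D conversions always yield an HWC array with exactly
# 3 channels (grayscale goes to channel 0; extra channels are zero-padded).
import numpy as np
gray = np.random.rand(128, 128).astype("float32")
out = convert_image(gray)                # -> (128, 128, 3)
two_chan = np.random.rand(2, 128, 128).astype("float32")
out2 = convert_image(two_chan)           # channel axis guessed, zero-padded -> (128, 128, 3)
print(out.shape, out2.shape)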
644
+ def normalize_img(img, normalize=True, norm3D=True, invert=False, lowhigh=None,
645
+ percentile=(1., 99.), sharpen_radius=0, smooth_radius=0,
646
+ tile_norm_blocksize=0, tile_norm_smooth3D=1, axis=-1):
647
+ """Normalize each channel of the image with optional inversion, smoothing, and sharpening.
648
+
649
+ Args:
650
+ img (ndarray): The input image. It should have at least 3 dimensions.
651
+ If it is 4-dimensional, it assumes the first non-channel axis is the Z dimension.
652
+ normalize (bool, optional): Whether to perform normalization. Defaults to True.
653
+ norm3D (bool, optional): Whether to normalize in 3D. If True, the entire 3D stack will
654
+ be normalized per channel. If False, normalization is applied per Z-slice. Defaults to True.
655
+ invert (bool, optional): Whether to invert the image. Useful if cells are dark instead of bright.
656
+ Defaults to False.
657
+ lowhigh (tuple or ndarray, optional): The lower and upper bounds for normalization.
658
+ Can be a tuple of two values (applied to all channels) or an array of shape (nchan, 2)
659
+ for per-channel normalization. Incompatible with smoothing and sharpening.
660
+ Defaults to None.
661
+ percentile (tuple, optional): The lower and upper percentiles for normalization. If provided, it should be
662
+ a tuple of two values. Each value should be between 0 and 100. Defaults to (1.0, 99.0).
663
+ sharpen_radius (int, optional): The radius for sharpening the image. Defaults to 0.
664
+ smooth_radius (int, optional): The radius for smoothing the image. Defaults to 0.
665
+ tile_norm_blocksize (int, optional): The block size for tile-based normalization. Defaults to 0.
666
+ tile_norm_smooth3D (int, optional): The smoothness factor for tile-based normalization in 3D. Defaults to 1.
667
+ axis (int, optional): The channel axis to loop over for normalization. Defaults to -1.
668
+
669
+ Returns:
670
+ ndarray: The normalized image of the same size.
671
+
672
+ Raises:
673
+ ValueError: If the image has less than 3 dimensions.
674
+ ValueError: If the provided lowhigh or percentile values are invalid.
675
+ ValueError: If the image is inverted without normalization.
676
+
677
+ """
678
+ if img.ndim < 3:
679
+ error_message = "Image needs to have at least 3 dimensions"
680
+ transforms_logger.critical(error_message)
681
+ raise ValueError(error_message)
682
+
683
+ img_norm = img if img.dtype=="float32" else img.astype(np.float32)
684
+ if axis != -1 and axis != img_norm.ndim - 1:
685
+ img_norm = np.moveaxis(img_norm, axis, -1) # Move channel axis to last
686
+
687
+ nchan = img_norm.shape[-1]
688
+
689
+ # Validate and handle lowhigh bounds
690
+ if lowhigh is not None:
691
+ lowhigh = np.array(lowhigh)
692
+ if lowhigh.shape == (2,):
693
+ lowhigh = np.tile(lowhigh, (nchan, 1)) # Expand to per-channel bounds
694
+ elif lowhigh.shape != (nchan, 2):
695
+ error_message = "`lowhigh` must have shape (2,) or (nchan, 2)"
696
+ transforms_logger.critical(error_message)
697
+ raise ValueError(error_message)
698
+
699
+ # Validate percentile
700
+ if percentile is None:
701
+ percentile = (1.0, 99.0)
702
+ elif not (0 <= percentile[0] < percentile[1] <= 100):
703
+ error_message = "Invalid percentile range, should be between 0 and 100"
704
+ transforms_logger.critical(error_message)
705
+ raise ValueError(error_message)
706
+
707
+ # Apply normalization based on lowhigh or percentile
708
+ cgood = np.zeros(nchan, "bool")
709
+ if lowhigh is not None:
710
+ for c in range(nchan):
711
+ lower = lowhigh[c, 0]
712
+ upper = lowhigh[c, 1]
713
+ img_norm[..., c] -= lower
714
+ img_norm[..., c] /= (upper - lower)
715
+ cgood[c] = True
716
+ else:
717
+ # Apply sharpening and smoothing if specified
718
+ if sharpen_radius > 0 or smooth_radius > 0:
719
+ img_norm = smooth_sharpen_img(
720
+ img_norm, sharpen_radius=sharpen_radius, smooth_radius=smooth_radius
721
+ )
722
+
723
+ # Apply tile-based normalization or standard normalization
724
+ if tile_norm_blocksize > 0:
725
+ img_norm = normalize99_tile(
726
+ img_norm,
727
+ blocksize=tile_norm_blocksize,
728
+ lower=percentile[0],
729
+ upper=percentile[1],
730
+ smooth3D=tile_norm_smooth3D,
731
+ norm3D=norm3D,
732
+ )
733
+ cgood[:] = True
734
+ elif normalize:
735
+ if img_norm.ndim == 3 or norm3D: # i.e. if YXC, or ZYXC with norm3D=True
736
+ for c in range(nchan):
737
+ if np.ptp(img_norm[..., c]) > 0.:
738
+ img_norm[..., c] = normalize99(
739
+ img_norm[..., c],
740
+ lower=percentile[0],
741
+ upper=percentile[1],
742
+ copy=False, downsample=True,
743
+ )
744
+ cgood[c] = True
745
+ else: # i.e. if ZYXC with norm3D=False then per Z-slice
746
+ for z in range(img_norm.shape[0]):
747
+ for c in range(nchan):
748
+ if np.ptp(img_norm[z, ..., c]) > 0.:
749
+ img_norm[z, ..., c] = normalize99(
750
+ img_norm[z, ..., c],
751
+ lower=percentile[0],
752
+ upper=percentile[1],
753
+ copy=False, downsample=True,
754
+ )
755
+ cgood[c] = True
756
+
757
+
758
+ if invert:
759
+ if lowhigh is not None or tile_norm_blocksize > 0 or normalize:
760
+ for c in range(nchan):
761
+ if cgood[c]:
762
+ img_norm[..., c] = 1 - img_norm[..., c]
763
+ else:
764
+ error_message = "Cannot invert image without normalization"
765
+ transforms_logger.critical(error_message)
766
+ raise ValueError(error_message)
767
+
768
+ # Move channel axis back to the original position
769
+ if axis != -1 and axis != img_norm.ndim - 1:
770
+ img_norm = np.moveaxis(img_norm, -1, axis)
771
+
772
+ # The transformer can get confused if a channel is all 1's instead of all 0's:
773
+ for i, chan_did_normalize in enumerate(cgood):
774
+ if not chan_did_normalize:
775
+ if img_norm.ndim == 3:
776
+ img_norm[:, :, i] = 0
777
+ if img_norm.ndim == 4:
778
+ img_norm[:, :, :, i] = 0
779
+
780
+ return img_norm
781
+
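# Illustrative sketch: per-channel percentile normalization of an HWC image.
# Note that float32 inputs are normalized in place.
import numpy as np
img = np.random.rand(256, 256, 3).astype("float32") * 500.
img_n = normalize_img(img, normalize=True, percentile=(1., 99.))
print(img_n.shape, img_n.dtype)          # (256, 256, 3) float32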
782
+ def resize_safe(img, Ly, Lx, interpolation=cv2.INTER_LINEAR):
783
+ """OpenCV resize function does not support uint32.
784
+
785
+ This function converts the image to float32 before resizing and then converts it back to uint32. Not safe!
786
+ References issue: https://github.com/MouseLand/cellpose/issues/937
787
+
788
+ Implications:
789
+ * Runtime: Runtime increases by 5x-50x due to type casting. However, with resizing being very efficient, this is not
790
+ a big issue. A 10,000x10,000 image takes 0.47s instead of 0.016s to cast and resize on 32 cores on GPU.
791
+ * Memory: However, memory usage increases. Not tested by how much.
792
+
793
+ Args:
794
+ img (ndarray): Image of size [Ly x Lx].
795
+ Ly (int): Desired height of the resized image.
796
+ Lx (int): Desired width of the resized image.
797
+ interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR.
798
+
799
+ Returns:
800
+ ndarray: Resized image of size [Ly x Lx].
801
+
802
+ """
803
+
804
+ # cast image
805
+ cast = img.dtype == np.uint32
806
+ if cast:
807
+ img = img.astype(np.float32)
808
+
809
+ # resize
810
+ img = cv2.resize(img, (Lx, Ly), interpolation=interpolation)
811
+
812
+ # cast back
813
+ if cast:
814
+ img = img.round().astype(np.uint32)
815
+
816
+ return img
817
+
818
+
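# Illustrative sketch: resizing a uint32 label image, which cv2.resize rejects
# natively; nearest-neighbor keeps the label values intact.
import cv2
import numpy as np
masks = np.zeros((512, 512), np.uint32)
masks[100:200, 100:200] = 7
small = resize_safe(masks, 256, 256, interpolation=cv2.INTER_NEAREST)
print(small.dtype, small.max())          # uint32 7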
819
+ def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEAR,
820
+ no_channels=False):
821
+ """Resize image for computing flows / unresize for computing dynamics.
822
+
823
+ Args:
824
+ img0 (ndarray): Image of size [Y x X x nchan] or [Lz x Y x X x nchan] or [Lz x Y x X].
825
+ Ly (int, optional): Desired height of the resized image. Defaults to None.
826
+ Lx (int, optional): Desired width of the resized image. Defaults to None.
827
+ rsz (float, optional): Resize coefficient(s) for the image. If Ly is None, rsz is used. Defaults to None.
828
+ interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR.
829
+ no_channels (bool, optional): Flag indicating whether to treat the third dimension as a channel.
830
+ Defaults to False.
831
+
832
+ Returns:
833
+ ndarray: Resized image of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan].
834
+
835
+ Raises:
836
+ ValueError: If Ly is None and rsz is None.
837
+
838
+ """
839
+ if Ly is None and rsz is None:
840
+ error_message = "must give size to resize to or factor to use for resizing"
841
+ transforms_logger.critical(error_message)
842
+ raise ValueError(error_message)
843
+
844
+ if Ly is None:
845
+ # determine Ly and Lx using rsz
846
+ if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
847
+ rsz = [rsz, rsz]
848
+ if no_channels:
849
+ Ly = int(img0.shape[-2] * rsz[-2])
850
+ Lx = int(img0.shape[-1] * rsz[-1])
851
+ else:
852
+ Ly = int(img0.shape[-3] * rsz[-2])
853
+ Lx = int(img0.shape[-2] * rsz[-1])
854
+
855
+ # no_channels useful for z-stacks, so the third dimension is not treated as a channel
856
+ # but if this is called for grayscale images, they first become [Ly,Lx,2], so ndim=3 even though the last axis holds channels
857
+ if (img0.ndim > 2 and no_channels) or (img0.ndim == 4 and not no_channels):
858
+ if Ly == 0 or Lx == 0:
859
+ raise ValueError(
860
+ "anisotropy too high / low -- not enough pixels to resize to ratio")
861
+ for i, img in enumerate(img0):
862
+ imgi = resize_safe(img, Ly, Lx, interpolation=interpolation)
863
+ if i==0:
864
+ if no_channels:
865
+ imgs = np.zeros((img0.shape[0], Ly, Lx), imgi.dtype)
866
+ else:
867
+ imgs = np.zeros((img0.shape[0], Ly, Lx, img0.shape[-1]), imgi.dtype)
868
+ imgs[i] = imgi if imgi.ndim > 2 or no_channels else imgi[..., np.newaxis]
869
+ else:
870
+ imgs = resize_safe(img0, Ly, Lx, interpolation=interpolation)
871
+ return imgs
872
+
873
+ def get_pad_yx(Ly, Lx, div=16, extra=1, min_size=None):
874
+ if min_size is None or Ly >= min_size[-2]:
875
+ Lpad = int(div * np.ceil(Ly / div) - Ly)
876
+ else:
877
+ Lpad = min_size[-2] - Ly
878
+ ypad1 = extra * div // 2 + Lpad // 2
879
+ ypad2 = extra * div // 2 + Lpad - Lpad // 2
880
+ if min_size is None or Lx >= min_size[-1]:
881
+ Lpad = int(div * np.ceil(Lx / div) - Lx)
882
+ else:
883
+ Lpad = min_size[-1] - Lx
884
+ xpad1 = extra * div // 2 + Lpad // 2
885
+ xpad2 = extra * div // 2 + Lpad - Lpad // 2
886
+
887
+ return ypad1, ypad2, xpad1, xpad2
888
+
889
+
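# Worked example (editor's illustration): for a 100x100 image with div=16 and
# extra=1, Lpad = 16*ceil(100/16) - 100 = 112 - 100 = 12, so each side gets
# 8 + 6 = 14 pixels, giving a padded size of 100 + 28 = 128 (a multiple of 16).
ypad1, ypad2, xpad1, xpad2 = get_pad_yx(100, 100, div=16, extra=1)
print(ypad1, ypad2, xpad1, xpad2)        # 14 14 14 14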
890
+ def pad_image_ND(img0, div=16, extra=1, min_size=None, zpad=False):
891
+ """Pad image for test-time so that its dimensions are a multiple of 16 (2D or 3D).
892
+
893
+ Args:
894
+ img0 (ndarray): Image of size [nchan (x Lz) x Ly x Lx].
895
+ div (int, optional): Divisor for padding. Defaults to 16.
896
+ extra (int, optional): Extra padding. Defaults to 1.
897
+ min_size (tuple, optional): Minimum size of the image. Defaults to None.
898
+
899
+ Returns:
900
+ A tuple containing (I, ysub, xsub) or (I, ysub, xsub, zsub), I is padded image, -sub are ranges of pixels in the padded image corresponding to img0.
901
+
902
+ """
903
+ Ly, Lx = img0.shape[-2:]
904
+ ypad1, ypad2, xpad1, xpad2 = get_pad_yx(Ly, Lx, div=div, extra=extra, min_size=min_size)
905
+
906
+ if img0.ndim > 3:
907
+ if zpad:
908
+ Lpad = int(div * np.ceil(img0.shape[-3] / div) - img0.shape[-3])
909
+ zpad1 = extra * div // 2 + Lpad // 2
910
+ zpad2 = extra * div // 2 + Lpad - Lpad // 2
911
+ else:
912
+ zpad1, zpad2 = 0, 0
913
+ pads = np.array([[0, 0], [zpad1, zpad2], [ypad1, ypad2], [xpad1, xpad2]])
914
+ else:
915
+ pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])
916
+
917
+ I = np.pad(img0, pads, mode="constant")
918
+
919
+ ysub = np.arange(ypad1, ypad1 + Ly)
920
+ xsub = np.arange(xpad1, xpad1 + Lx)
921
+ if zpad:
922
+ zsub = np.arange(zpad1, zpad1 + img0.shape[-3])
923
+ return I, ysub, xsub, zsub
924
+ else:
925
+ return I, ysub, xsub
926
+
927
+
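# Illustrative sketch: padding a 1-channel 100x100 image for the network,
# then cropping back to the original region with the returned index ranges.
import numpy as np
img = np.random.rand(1, 100, 100).astype("float32")
I, ysub, xsub = pad_image_ND(img, div=16, extra=1)
print(I.shape)                           # (1, 128, 128)
recovered = I[:, ysub[0]:ysub[-1] + 1, xsub[0]:xsub[-1] + 1]
assert recovered.shape == img.shape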
928
+ def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=False,
929
+ zcrop=48, do_flip=True, rotate=True, rescale=None, unet=False,
930
+ random_per_image=True):
931
+ """Augmentation by random rotation and resizing.
932
+
933
+ Args:
934
+ X (list of ND-arrays, float): List of image arrays of size [nchan x Ly x Lx] or [Ly x Lx].
935
+ Y (list of ND-arrays, float, optional): List of image labels of size [nlabels x Ly x Lx] or [Ly x Lx].
936
+ The 1st channel of Y is always nearest-neighbor interpolated (assumed to be masks or 0-1 representation).
937
+ If Y.shape[0]==3 and not unet, then the labels are assumed to be [cell probability, Y flow, X flow].
938
+ If unet, second channel is dist_to_bound. Defaults to None.
939
+ scale_range (float, optional): Range of resizing of images for augmentation.
940
+ Images are resized by (1-scale_range/2) + scale_range * np.random.rand(). Defaults to 1.0.
941
+ xy (tuple, int, optional): Size of transformed images to return. Defaults to (224,224).
+ do_3D (bool, optional): Whether to treat inputs as 3D stacks and crop in z. Defaults to False.
+ zcrop (int, optional): Size of the z-crop used when do_3D is True. Defaults to 48.
942
+ do_flip (bool, optional): Whether or not to flip images horizontally. Defaults to True.
943
+ rotate (bool, optional): Whether or not to rotate images. Defaults to True.
944
+ rescale (array, float, optional): How much to resize images by before performing augmentations. Defaults to None.
945
+ unet (bool, optional): Whether or not to use unet. Defaults to False.
946
+ random_per_image (bool, optional): Different random rotate and resize per image. Defaults to True.
947
+
948
+ Returns:
949
+ A tuple containing (imgi, lbl, scale): imgi (ND-array, float): Transformed images in array [nimg x nchan x xy[0] x xy[1]];
950
+ lbl (ND-array, float): Transformed labels in array [nimg x nchan x xy[0] x xy[1]];
951
+ scale (array, float): Amount each image was resized by.
952
+ """
953
+ scale_range = max(0, min(2, float(scale_range))) if scale_range is not None else scale_range
954
+ nimg = len(X)
955
+ if X[0].ndim > 2:
956
+ nchan = X[0].shape[0]
957
+ else:
958
+ nchan = 1
959
+ if do_3D and X[0].ndim > 3:
960
+ shape = (zcrop, xy[0], xy[1])
961
+ else:
962
+ shape = (xy[0], xy[1])
963
+ imgi = np.zeros((nimg, nchan, *shape), "float32")
964
+
965
+ lbl = []
966
+ if Y is not None:
967
+ if Y[0].ndim > 2:
968
+ nt = Y[0].shape[0]
969
+ else:
970
+ nt = 1
971
+ lbl = np.zeros((nimg, nt, *shape), np.float32)
972
+
973
+ scale = np.ones(nimg, np.float32)
974
+
975
+ for n in range(nimg):
976
+
977
+ if random_per_image or n == 0:
978
+ Ly, Lx = X[n].shape[-2:]
979
+ # generate random augmentation parameters
980
+ flip = np.random.rand() > .5
981
+ theta = np.random.rand() * np.pi * 2 if rotate else 0.
982
+ if scale_range is None:
983
+ scale[n] = 2 ** (4 * np.random.rand() - 2)
984
+ else:
985
+ scale[n] = (1 - scale_range / 2) + scale_range * np.random.rand()
986
+ if rescale is not None:
987
+ scale[n] *= 1. / rescale[n]
988
+ dxy = np.maximum(0, np.array([Lx * scale[n] - xy[1],
989
+ Ly * scale[n] - xy[0]]))
990
+ dxy = (np.random.rand(2,) - .5) * dxy
991
+
992
+ # create affine transform
993
+ cc = np.array([Lx / 2, Ly / 2])
994
+ cc1 = cc - np.array([Lx - xy[1], Ly - xy[0]]) / 2 + dxy
995
+ pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
996
+ pts2 = np.float32([
997
+ cc1,
998
+ cc1 + scale[n] * np.array([np.cos(theta), np.sin(theta)]),
999
+ cc1 + scale[n] *
1000
+ np.array([np.cos(np.pi / 2 + theta),
1001
+ np.sin(np.pi / 2 + theta)])
1002
+ ])
1003
+ M = cv2.getAffineTransform(pts1, pts2)
1004
+
1005
+ img = X[n].copy()
1006
+ if Y is not None:
1007
+ labels = Y[n].copy()
1008
+ if labels.ndim < 3:
1009
+ labels = labels[np.newaxis, :, :]
1010
+
1011
+ if do_3D:
1012
+ Lz = X[n].shape[-3]
1013
+ flip_z = np.random.rand() > .5
1014
+ lz = int(np.round(zcrop / scale[n]))
1015
+ iz = np.random.randint(0, Lz - lz)
1016
+ img = img[:,iz:iz + lz,:,:]
1017
+ if Y is not None:
1018
+ labels = labels[:,iz:iz + lz,:,:]
1019
+
1020
+ if do_flip:
1021
+ if flip:
1022
+ img = img[..., ::-1]
1023
+ if Y is not None:
1024
+ labels = labels[..., ::-1]
1025
+ if nt > 1 and not unet:
1026
+ labels[-1] = -labels[-1]
1027
+ if do_3D and flip_z:
1028
+ img = img[:, ::-1]
1029
+ if Y is not None:
1030
+ labels = labels[:,::-1]
1031
+ if nt > 1 and not unet:
1032
+ labels[-3] = -labels[-3]
1033
+
1034
+ for k in range(nchan):
1035
+ if do_3D:
1036
+ img0 = np.zeros((lz, xy[0], xy[1]), "float32")
1037
+ for z in range(lz):
1038
+ I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
1039
+ flags=cv2.INTER_LINEAR)
1040
+ img0[z] = I
1041
+ if scale[n] != 1.0:
1042
+ for y in range(imgi.shape[-2]):
1043
+ imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
1044
+ interpolation=cv2.INTER_LINEAR)
1045
+ else:
1046
+ imgi[n, k] = img0
1047
+ else:
1048
+ I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
1049
+ imgi[n, k] = I
1050
+
1051
+ if Y is not None:
1052
+ for k in range(nt):
1053
+ flag = cv2.INTER_NEAREST if k < nt-2 else cv2.INTER_LINEAR
1054
+ if do_3D:
1055
+ lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
1056
+ for z in range(lz):
1057
+ I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
1058
+ flags=flag)
1059
+ lbl0[z] = I
1060
+ if scale[n] != 1.0:
1061
+ for y in range(lbl.shape[-2]):
1062
+ lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
1063
+ interpolation=flag)
1064
+ else:
1065
+ lbl[n, k] = lbl0
1066
+ else:
1067
+ lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
1068
+
1069
+ if nt > 1 and not unet:
1070
+ v1 = lbl[n, -1].copy()
1071
+ v2 = lbl[n, -2].copy()
1072
+ lbl[n, -2] = (-v1 * np.sin(-theta) + v2 * np.cos(-theta))
1073
+ lbl[n, -1] = (v1 * np.cos(-theta) + v2 * np.sin(-theta))
1074
+
1075
+ return imgi, lbl, scale
1076
+
1077
+
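# Illustrative sketch: augmenting a batch of two 2-channel images without
# labels (Y=None returns an empty lbl list).
import numpy as np
X = [np.random.rand(2, 300, 300).astype("float32") for _ in range(2)]
imgi, lbl, scale = random_rotate_and_resize(X, xy=(224, 224), scale_range=0.5)
print(imgi.shape, scale)                 # (2, 2, 224, 224), per-image scale factors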
1078
+ def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=(224, 224), do_3D=False,
1079
+ zcrop=48, do_flip=True, rotate=True, rescale=None, unet=False,
1080
+ random_per_image=True):
1081
+ """Augmentation by random rotation and resizing.
1082
+
1083
+ Args:
1084
+ X (list of ND-arrays, float): List of image arrays of size [nchan x Ly x Lx] or [Ly x Lx].
1085
+ Y (list of ND-arrays, float, optional): List of image labels of size [nlabels x Ly x Lx] or [Ly x Lx].
1086
+ The 1st channel of Y is always nearest-neighbor interpolated (assumed to be masks or 0-1 representation).
1087
+ If Y.shape[0]==3 and not unet, then the labels are assumed to be [cell probability, Y flow, X flow].
1088
+ If unet, second channel is dist_to_bound. Defaults to None.
+ feat (list of ND-arrays, float, optional): List of feature maps of size [nfeat x Ly x Lx], warped with bilinear interpolation alongside the images. Defaults to None.
1089
+ scale_range (float, optional): Range of resizing of images for augmentation.
1090
+ Images are resized by (1-scale_range/2) + scale_range * np.random.rand(). Defaults to 1.0.
1091
+ xy (tuple, int, optional): Size of transformed images to return. Defaults to (224,224).
+ do_3D (bool, optional): Whether to treat inputs as 3D stacks and crop in z. Defaults to False.
+ zcrop (int, optional): Size of the z-crop used when do_3D is True. Defaults to 48.
1092
+ do_flip (bool, optional): Whether or not to flip images horizontally. Defaults to True.
1093
+ rotate (bool, optional): Whether or not to rotate images. Defaults to True.
1094
+ rescale (array, float, optional): How much to resize images by before performing augmentations. Defaults to None.
1095
+ unet (bool, optional): Whether or not to use unet. Defaults to False.
1096
+ random_per_image (bool, optional): Different random rotate and resize per image. Defaults to True.
1097
+
1098
+ Returns:
1099
+ A tuple containing (imgi, lbl, feat_out, scale): imgi (ND-array, float): Transformed images in array [nimg x nchan x xy[0] x xy[1]];
1100
+ lbl (ND-array, float): Transformed labels in array [nimg x nt x xy[0] x xy[1]];
1101
+ feat_out (ND-array, float or None): Transformed features in array [nimg x nf x xy[0] x xy[1]], or None if feat is None; scale (array, float): Amount each image was resized by.
1102
+ """
1103
+ scale_range = max(0, min(2, float(scale_range))) if scale_range is not None else scale_range
1104
+ nimg = len(X)
1105
+ if X[0].ndim > 2:
1106
+ nchan = X[0].shape[0]
1107
+ else:
1108
+ nchan = 1
1109
+ if do_3D and X[0].ndim > 3:
1110
+ shape = (zcrop, xy[0], xy[1])
1111
+ else:
1112
+ shape = (xy[0], xy[1])
1113
+ imgi = np.zeros((nimg, nchan, *shape), "float32")
1114
+
1115
+ lbl = []
1116
+ if Y is not None:
1117
+ if Y[0].ndim > 2:
1118
+ nt = Y[0].shape[0]
1119
+ else:
1120
+ nt = 1
1121
+ lbl = np.zeros((nimg, nt, *shape), np.float32)
1122
+
1123
+ if feat is not None:
1124
+ if feat[0].ndim > 2:
1125
+ nf = feat[0].shape[0]
1126
+ else:
1127
+ nf = 1
1128
+ feat_out = np.zeros((nimg, nf, *shape), "float32")
+ else:
+ feat_out = None  # keep the 4-tuple return shape even when no features are provided
1129
+
1130
+ scale = np.ones(nimg, np.float32)
1131
+
1132
+ for n in range(nimg):
1133
+
1134
+ if random_per_image or n == 0:
1135
+ Ly, Lx = X[n].shape[-2:]
1136
+ # generate random augmentation parameters
1137
+ flip = np.random.rand() > .5
1138
+ theta = np.random.rand() * np.pi * 2 if rotate else 0.
1139
+ if scale_range is None:
1140
+ scale[n] = 2 ** (4 * np.random.rand() - 2)
1141
+ else:
1142
+ scale[n] = (1 - scale_range / 2) + scale_range * np.random.rand()
1143
+ if rescale is not None:
1144
+ scale[n] *= 1. / rescale[n]
1145
+ dxy = np.maximum(0, np.array([Lx * scale[n] - xy[1],
1146
+ Ly * scale[n] - xy[0]]))
1147
+ dxy = (np.random.rand(2,) - .5) * dxy
1148
+
1149
+ # create affine transform
1150
+ cc = np.array([Lx / 2, Ly / 2])
1151
+ cc1 = cc - np.array([Lx - xy[1], Ly - xy[0]]) / 2 + dxy
1152
+ pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
1153
+ pts2 = np.float32([
1154
+ cc1,
1155
+ cc1 + scale[n] * np.array([np.cos(theta), np.sin(theta)]),
1156
+ cc1 + scale[n] *
1157
+ np.array([np.cos(np.pi / 2 + theta),
1158
+ np.sin(np.pi / 2 + theta)])
1159
+ ])
1160
+ M = cv2.getAffineTransform(pts1, pts2)
1161
+
1162
+ img = X[n].copy()
1163
+ if Y is not None:
1164
+ labels = Y[n].copy()
1165
+ if labels.ndim < 3:
1166
+ labels = labels[np.newaxis, :, :]
1167
+ if feat is not None:
1168
+ feats = feat[n].copy()
1169
+ if feats.ndim < 3:
1170
+ feats = feats[np.newaxis, :, :]
1171
+
1172
+ if do_3D:
1173
+ Lz = X[n].shape[-3]
1174
+ flip_z = np.random.rand() > .5
1175
+ lz = int(np.round(zcrop / scale[n]))
1176
+ iz = np.random.randint(0, Lz - lz)
1177
+ img = img[:,iz:iz + lz,:,:]
1178
+ if Y is not None:
1179
+ labels = labels[:,iz:iz + lz,:,:]
1180
+ if feat is not None:
1181
+ feats = feats[:,iz:iz + lz,:,:]
1182
+
1183
+ if do_flip:
1184
+ if flip:
1185
+ img = img[..., ::-1]
1186
+ if Y is not None:
1187
+ labels = labels[..., ::-1]
1188
+ if nt > 1 and not unet:
1189
+ labels[-1] = -labels[-1]
1190
+ if feat is not None:
1191
+ feats = feats[..., ::-1]
1192
+ if do_3D and flip_z:
1193
+ img = img[:, ::-1]
1194
+ if Y is not None:
1195
+ labels = labels[:,::-1]
1196
+ if nt > 1 and not unet:
1197
+ labels[-3] = -labels[-3]
1198
+ if feat is not None:
1199
+ feats = feats[:, ::-1]
1200
+
1201
+ for k in range(nchan):
1202
+ if do_3D:
1203
+ img0 = np.zeros((lz, xy[0], xy[1]), "float32")
1204
+ for z in range(lz):
1205
+ I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
1206
+ flags=cv2.INTER_LINEAR)
1207
+ img0[z] = I
1208
+ if scale[n] != 1.0:
1209
+ for y in range(imgi.shape[-2]):
1210
+ imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
1211
+ interpolation=cv2.INTER_LINEAR)
1212
+ else:
1213
+ imgi[n, k] = img0
1214
+ else:
1215
+ I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
1216
+ imgi[n, k] = I
1217
+
1218
+ if Y is not None:
1219
+ for k in range(nt):
1220
+ flag = cv2.INTER_NEAREST if k < nt-2 else cv2.INTER_LINEAR
1221
+ if do_3D:
1222
+ lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
1223
+ for z in range(lz):
1224
+ I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
1225
+ flags=flag)
1226
+ lbl0[z] = I
1227
+ if scale[n] != 1.0:
1228
+ for y in range(lbl.shape[-2]):
1229
+ lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
1230
+ interpolation=flag)
1231
+ else:
1232
+ lbl[n, k] = lbl0
1233
+ else:
1234
+ lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)
1235
+
1236
+ if nt > 1 and not unet:
1237
+ v1 = lbl[n, -1].copy()
1238
+ v2 = lbl[n, -2].copy()
1239
+ lbl[n, -2] = (-v1 * np.sin(-theta) + v2 * np.cos(-theta))
1240
+ lbl[n, -1] = (v1 * np.cos(-theta) + v2 * np.sin(-theta))
1241
+
1242
+ if feat is not None:
1243
+ for k in range(nf):
1244
+ if do_3D:
1245
+ feat0 = np.zeros((lz, xy[0], xy[1]), "float32")
1246
+ for z in range(lz):
1247
+ I = cv2.warpAffine(feats[k, z], M, (xy[1], xy[0]),
1248
+ flags=cv2.INTER_LINEAR)
1249
+ feat0[z] = I
1250
+ if scale[n] != 1.0:
1251
+ for y in range(feat_out.shape[-2]):
1252
+ feat_out[n, k, :, y] = cv2.resize(feat0[:, y], (xy[1], zcrop),
1253
+ interpolation=cv2.INTER_LINEAR)
1254
+ else:
1255
+ feat_out[n, k] = feat0
1256
+ else:
1257
+ feat_out[n, k] = cv2.warpAffine(feats[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
1258
+
1259
+
1260
+
1261
+ return imgi, lbl, feat_out, scale
models/seg_post_model/cellpose/utils.py ADDED
@@ -0,0 +1,667 @@
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import logging
5
+ import os, tempfile, shutil, io
6
+ from tqdm import tqdm, trange
7
+ from urllib.request import urlopen
8
+ import cv2
9
+ from scipy.ndimage import find_objects, gaussian_filter, generate_binary_structure, label
10
+ from scipy.spatial import ConvexHull
11
+ import numpy as np
12
+ import colorsys
13
+ import fastremap
14
+ import fill_voids
15
+ from multiprocessing import Pool, cpu_count
16
+ # try:
17
+ # from cellpose import metrics
18
+ # except:
19
+ # import metrics as metrics
20
+ from models.seg_post_model.cellpose import metrics
21
+
22
+ try:
23
+ from skimage.morphology import remove_small_holes
24
+ SKIMAGE_ENABLED = True
25
+ except:
26
+ SKIMAGE_ENABLED = False
27
+
28
+
29
+ class TqdmToLogger(io.StringIO):
30
+ """
31
+ Output stream for TQDM which will output to logger module instead of
32
+ the StdOut.
33
+ """
34
+ logger = None
35
+ level = None
36
+ buf = ""
37
+
38
+ def __init__(self, logger, level=None):
39
+ super(TqdmToLogger, self).__init__()
40
+ self.logger = logger
41
+ self.level = level or logging.INFO
42
+
43
+ def write(self, buf):
44
+ self.buf = buf.strip("\r\n\t ")
45
+
46
+ def flush(self):
47
+ self.logger.log(self.level, self.buf)
48
+
49
+
50
+ def rgb_to_hsv(arr):
51
+ rgb_to_hsv_channels = np.vectorize(colorsys.rgb_to_hsv)
52
+ r, g, b = np.rollaxis(arr, axis=-1)
53
+ h, s, v = rgb_to_hsv_channels(r, g, b)
54
+ hsv = np.stack((h, s, v), axis=-1)
55
+ return hsv
56
+
57
+
58
+ def hsv_to_rgb(arr):
59
+ hsv_to_rgb_channels = np.vectorize(colorsys.hsv_to_rgb)
60
+ h, s, v = np.rollaxis(arr, axis=-1)
61
+ r, g, b = hsv_to_rgb_channels(h, s, v)
62
+ rgb = np.stack((r, g, b), axis=-1)
63
+ return rgb
64
+
65
+
66
+ def download_url_to_file(url, dst, progress=True):
67
+ r"""Download object at the given URL to a local path.
68
+ Thanks to torch, slightly modified
69
+ Args:
70
+ url (string): URL of the object to download
71
+ dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
72
+ progress (bool, optional): whether or not to display a progress bar to stderr
73
+ Default: True
74
+ """
75
+ file_size = None
76
+ import ssl
77
+ ssl._create_default_https_context = ssl._create_unverified_context
78
+ u = urlopen(url)
79
+ meta = u.info()
80
+ if hasattr(meta, "getheaders"):
81
+ content_length = meta.getheaders("Content-Length")
82
+ else:
83
+ content_length = meta.get_all("Content-Length")
84
+ if content_length is not None and len(content_length) > 0:
85
+ file_size = int(content_length[0])
86
+ # We deliberately save it in a temp file and move it after
87
+ dst = os.path.expanduser(dst)
88
+ dst_dir = os.path.dirname(dst)
89
+ f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
90
+ try:
91
+ with tqdm(total=file_size, disable=not progress, unit="B", unit_scale=True,
92
+ unit_divisor=1024) as pbar:
93
+ while True:
94
+ buffer = u.read(8192)
95
+ if len(buffer) == 0:
96
+ break
97
+ f.write(buffer)
98
+ pbar.update(len(buffer))
99
+ f.close()
100
+ shutil.move(f.name, dst)
101
+ finally:
102
+ f.close()
103
+ if os.path.exists(f.name):
104
+ os.remove(f.name)
105
+
106
+
107
+ def distance_to_boundary(masks):
108
+ """Get the distance to the boundary of mask pixels.
109
+
110
+ Args:
111
+ masks (int, 2D or 3D array): The masks array. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
112
+
113
+ Returns:
114
+ dist_to_bound (2D or 3D array): The distance to the boundary. Size [Ly x Lx] or [Lz x Ly x Lx].
115
+
116
+ Raises:
117
+ ValueError: If the masks array is not 2D or 3D.
118
+
119
+ """
120
+ if masks.ndim > 3 or masks.ndim < 2:
121
+ raise ValueError("distance_to_boundary takes 2D or 3D array, not %dD array" %
122
+ masks.ndim)
123
+ dist_to_bound = np.zeros(masks.shape, np.float64)
124
+
125
+ if masks.ndim == 3:
126
+ for i in range(masks.shape[0]):
127
+ dist_to_bound[i] = distance_to_boundary(masks[i])
128
+ return dist_to_bound
129
+ else:
130
+ slices = find_objects(masks)
131
+ for i, si in enumerate(slices):
132
+ if si is not None:
133
+ sr, sc = si
134
+ mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
135
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
136
+ cv2.CHAIN_APPROX_NONE)
137
+ pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
138
+ ypix, xpix = np.nonzero(mask)
139
+ min_dist = ((ypix[:, np.newaxis] - pvr)**2 +
140
+ (xpix[:, np.newaxis] - pvc)**2).min(axis=1)
141
+ dist_to_bound[ypix + sr.start, xpix + sc.start] = min_dist
142
+ return dist_to_bound
143
+
144
+
145
+ def masks_to_edges(masks, threshold=1.0):
146
+ """Get edges of masks as a 0-1 array.
147
+
148
+ Args:
149
+ masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
150
+ threshold (float, optional): Threshold value for distance to boundary. Defaults to 1.0.
151
+
152
+ Returns:
153
+ edges (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are edge pixels.
154
+ """
155
+ dist_to_bound = distance_to_boundary(masks)
156
+ edges = (dist_to_bound < threshold) * (masks > 0)
157
+ return edges
158
+
159
+
160
+ def remove_edge_masks(masks, change_index=True):
161
+ """Removes masks with pixels on the edge of the image.
162
+
163
+ Args:
164
+ masks (int, 2D or 3D array): The masks to be processed. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
165
+ change_index (bool, optional): If True, after removing masks, changes the indexing so that there are no missing label numbers. Defaults to True.
166
+
167
+ Returns:
168
+ outlines (2D or 3D array): The processed masks. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
169
+ """
170
+ slices = find_objects(masks.astype(int))
171
+ for i, si in enumerate(slices):
172
+ remove = False
173
+ if si is not None:
174
+ for d, sid in enumerate(si):
175
+ if sid.start == 0 or sid.stop == masks.shape[d]:
176
+ remove = True
177
+ break
178
+ if remove:
179
+ masks[si][masks[si] == i + 1] = 0
180
+ shape = masks.shape
181
+ if change_index:
182
+ _, masks = np.unique(masks, return_inverse=True)
183
+ masks = np.reshape(masks, shape).astype(np.int32)
184
+
185
+ return masks
186
+
187
+
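# Illustrative sketch: masks touching the image border are removed and the
# remaining labels are reindexed (the function mutates its input, hence copy()).
import numpy as np
masks = np.zeros((32, 32), np.int32)
masks[0:5, 0:5] = 1        # touches the edge -> removed
masks[10:20, 10:20] = 2    # interior -> kept, relabeled to 1
clean = remove_edge_masks(masks.copy())
print(np.unique(clean))                  # [0 1]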
188
+ def masks_to_outlines(masks):
189
+ """Get outlines of masks as a 0-1 array.
190
+
191
+ Args:
192
+ masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
193
+
194
+ Returns:
195
+ outlines (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are outlines.
196
+ """
197
+ if masks.ndim > 3 or masks.ndim < 2:
198
+ raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
199
+ masks.ndim)
200
+ outlines = np.zeros(masks.shape, bool)
201
+
202
+ if masks.ndim == 3:
203
+ for i in range(masks.shape[0]):
204
+ outlines[i] = masks_to_outlines(masks[i])
205
+ return outlines
206
+ else:
207
+ slices = find_objects(masks.astype(int))
208
+ for i, si in enumerate(slices):
209
+ if si is not None:
210
+ sr, sc = si
211
+ mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
212
+ contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
213
+ cv2.CHAIN_APPROX_NONE)
214
+ pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
215
+ vr, vc = pvr + sr.start, pvc + sc.start
216
+ outlines[vr, vc] = 1
217
+ return outlines
218
+
219
+
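# Illustrative sketch: outline pixels of a single square mask.
import numpy as np
masks = np.zeros((64, 64), np.int32)
masks[10:30, 10:30] = 1
outlines = masks_to_outlines(masks)
print(outlines.sum())                    # boundary pixels of the 20x20 square (76)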
220
+ def outlines_list(masks, multiprocessing_threshold=1000, multiprocessing=None):
221
+ """Get outlines of masks as a list to loop over for plotting.
222
+
223
+ Args:
224
+ masks (ndarray): Array of masks.
225
+ multiprocessing_threshold (int, optional): Threshold for enabling multiprocessing. Defaults to 1000.
226
+ multiprocessing (bool, optional): Flag to enable multiprocessing. Defaults to None.
227
+
228
+ Returns:
229
+ list: List of outlines.
230
+
231
+ Raises:
232
+ None
233
+
234
+ Notes:
235
+ - This function is a wrapper for outlines_list_single and outlines_list_multi.
236
+ - Multiprocessing is disabled for Windows.
237
+ """
238
+ # default to use multiprocessing if not few_masks, but allow user to override
239
+ if multiprocessing is None:
240
+ few_masks = np.max(masks) < multiprocessing_threshold
241
+ multiprocessing = not few_masks
242
+
243
+ # disable multiprocessing for Windows
244
+ if os.name == "nt":
245
+ if multiprocessing:
246
+ logging.getLogger(__name__).warning(
247
+ "Multiprocessing is disabled for Windows")
248
+ multiprocessing = False
249
+
250
+ if multiprocessing:
251
+ return outlines_list_multi(masks)
252
+ else:
253
+ return outlines_list_single(masks)
254
+
255
+
256
+ def outlines_list_single(masks):
257
+ """Get outlines of masks as a list to loop over for plotting.
258
+
259
+ Args:
260
+ masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
261
+
262
+ Returns:
263
+ list: List of outlines as pixel coordinates.
264
+
265
+ """
266
+ outpix = []
267
+ for n in np.unique(masks)[1:]:
268
+ mn = masks == n
269
+ if mn.sum() > 0:
270
+ contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
271
+ method=cv2.CHAIN_APPROX_NONE)
272
+ contours = contours[-2]
273
+ cmax = np.argmax([c.shape[0] for c in contours])
274
+ pix = contours[cmax].astype(int).squeeze()
275
+ if len(pix) > 4:
276
+ outpix.append(pix)
277
+ else:
278
+ outpix.append(np.zeros((0, 2)))
279
+ return outpix
280
+
281
+
282
+ def outlines_list_multi(masks, num_processes=None):
283
+ """
284
+ Get outlines of masks as a list to loop over for plotting.
285
+
286
+ Args:
287
+ masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+ num_processes (int, optional): Number of worker processes. Defaults to cpu_count().
288
+
289
+ Returns:
290
+ list: List of outlines as pixel coordinates.
291
+ """
292
+ if num_processes is None:
293
+ num_processes = cpu_count()
294
+
295
+ unique_masks = np.unique(masks)[1:]
296
+ with Pool(processes=num_processes) as pool:
297
+ outpix = pool.map(get_outline_multi, [(masks, n) for n in unique_masks])
298
+ return outpix
299
+
300
+
301
+ def get_outline_multi(args):
302
+ """Get the outline of a specific mask in a multi-mask image.
303
+
304
+ Args:
305
+ args (tuple): A tuple containing the masks and the mask number.
306
+
307
+ Returns:
308
+ numpy.ndarray: The outline of the specified mask as an array of coordinates.
309
+
310
+ """
311
+ masks, n = args
312
+ mn = masks == n
313
+ if mn.sum() > 0:
314
+ contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
315
+ method=cv2.CHAIN_APPROX_NONE)
316
+ contours = contours[-2]
317
+ cmax = np.argmax([c.shape[0] for c in contours])
318
+ pix = contours[cmax].astype(int).squeeze()
319
+ return pix if len(pix) > 4 else np.zeros((0, 2))
320
+ return np.zeros((0, 2))
321
+
322
+
323
+ def dilate_masks(masks, n_iter=5):
324
+ """Dilate masks by n_iter pixels.
325
+
326
+ Args:
327
+ masks (ndarray): Array of masks.
328
+ n_iter (int, optional): Number of pixels to dilate the masks. Defaults to 5.
329
+
330
+ Returns:
331
+ ndarray: Dilated masks.
332
+ """
333
+ dilated_masks = masks.copy()
334
+ for n in range(n_iter):
335
+ # define the structuring element to use for dilation
336
+ kernel = np.ones((3, 3), "uint8")
337
+ # find the distance to each mask (distances are zero within masks)
338
+ dist_transform = cv2.distanceTransform((dilated_masks == 0).astype("uint8"),
339
+ cv2.DIST_L2, 5)
340
+ # dilate each mask and assign to it the pixels along the border of the mask
341
+ # (does not allow dilation into other masks since dist_transform is zero there)
342
+ for i in range(1, np.max(masks) + 1):
343
+ mask = (dilated_masks == i).astype("uint8")
344
+ dilated_mask = cv2.dilate(mask, kernel, iterations=1)
345
+ dilated_mask = np.logical_and(dist_transform < 2, dilated_mask)
346
+ dilated_masks[dilated_mask > 0] = i
347
+ return dilated_masks
348
+
349
+
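# Illustrative sketch: dilating two nearby masks; the distance-transform guard
# lets each mask grow into background only, not into its neighbor.
import numpy as np
masks = np.zeros((50, 50), np.int32)
masks[10:20, 10:20] = 1
masks[10:20, 25:35] = 2
grown = dilate_masks(masks, n_iter=2)
print((grown == 1).sum() > (masks == 1).sum())   # True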
350
+ def get_perimeter(points):
351
+ """
352
+ Calculate the perimeter of a set of points.
353
+
354
+ Parameters:
355
+ points (ndarray): An array of points with shape (npoints, ndim).
356
+
357
+ Returns:
358
+ float: The perimeter of the points.
359
+
360
+ """
361
+ if points.shape[0] > 4:
362
+ points = np.append(points, points[:1], axis=0)
363
+ return ((np.diff(points, axis=0)**2).sum(axis=1)**0.5).sum()
364
+ else:
365
+ return 0
366
+
367
+
368
+ def get_mask_compactness(masks):
369
+ """
370
+ Calculate the compactness of masks.
371
+
372
+ Parameters:
373
+ masks (ndarray): Binary masks representing objects.
374
+
375
+ Returns:
376
+ ndarray: Array of compactness values for each mask.
377
+ """
378
+ perimeters = get_mask_perimeters(masks)
379
+ npoints = np.unique(masks, return_counts=True)[1][1:]
380
+ areas = npoints
381
+ compactness = 4 * np.pi * areas / perimeters**2
382
+ compactness[perimeters == 0] = 0
383
+ compactness[compactness > 1.0] = 1.0
384
+ return compactness
385
+
386
+
387
+ def get_mask_perimeters(masks):
388
+ """
389
+ Calculate the perimeters of the given masks.
390
+
391
+ Parameters:
392
+ masks (numpy.ndarray): Binary masks representing objects.
393
+
394
+ Returns:
395
+ numpy.ndarray: Array containing the perimeters of each mask.
396
+ """
397
+ perimeters = np.zeros(masks.max())
398
+ for n in range(masks.max()):
399
+ mn = masks == (n + 1)
400
+ if mn.sum() > 0:
401
+ contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
402
+ method=cv2.CHAIN_APPROX_NONE)[-2]
403
+ perimeters[n] = np.array(
404
+ [get_perimeter(c.astype(int).squeeze()) for c in contours]).sum()
405
+
406
+ return perimeters
407
+
408
+
409
+ def circleMask(d0):
410
+ """
411
+ Creates an array with indices which are the radius of that x,y point.
412
+
413
+ Args:
414
+ d0 (tuple): Patch of (-d0, d0+1) over which radius is computed.
415
+
416
+ Returns:
417
+ tuple: A tuple containing:
418
+ - rs (ndarray): Array of radii with shape (2*d0[0]+1, 2*d0[1]+1).
419
+ - dx (ndarray): Indices of the patch along the x-axis.
420
+ - dy (ndarray): Indices of the patch along the y-axis.
421
+ """
422
+ dx = np.tile(np.arange(-d0[1], d0[1] + 1), (2 * d0[0] + 1, 1))
423
+ dy = np.tile(np.arange(-d0[0], d0[0] + 1), (2 * d0[1] + 1, 1))
424
+ dy = dy.transpose()
425
+
426
+ rs = (dy**2 + dx**2)**0.5
427
+ return rs, dx, dy
428
+
429
+
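# Illustrative sketch: radii grid for a small patch; the center has radius 0.
import numpy as np
rs, dx, dy = circleMask(np.array([2, 2]))
print(rs.shape)                          # (5, 5)
print(rs[2, 2])                          # 0.0 at the patch center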
430
+ def get_mask_stats(masks_true):
+     """
+     Calculate various statistics for the given labeled masks.
+
+     Parameters:
+         masks_true (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+
+     Returns:
+         convexity (ndarray): Convexity values for each mask.
+         solidity (ndarray): Solidity values for each mask.
+         compactness (ndarray): Compactness values for each mask.
+     """
+     mask_perimeters = get_mask_perimeters(masks_true)
+
+     # disk for compactness
+     rs, dx, dy = circleMask(np.array([100, 100]))
+     rsort = np.sort(rs.flatten())
+
+     # area for solidity
+     npoints = np.unique(masks_true, return_counts=True)[1][1:]
+     areas = npoints - mask_perimeters / 2 - 1
+
+     compactness = np.zeros(masks_true.max())
+     convexity = np.zeros(masks_true.max())
+     solidity = np.zeros(masks_true.max())
+     convex_perimeters = np.zeros(masks_true.max())
+     convex_areas = np.zeros(masks_true.max())
+     for ic in range(masks_true.max()):
+         points = np.array(np.nonzero(masks_true == (ic + 1))).T
+         if len(points) > 15 and mask_perimeters[ic] > 0:
+             med = np.median(points, axis=0)
+             # compute compactness of ROI
+             r2 = ((points - med)**2).sum(axis=1)**0.5
+             compactness[ic] = (rsort[:r2.size].mean() + 1e-10) / r2.mean()
+             try:
+                 hull = ConvexHull(points)
+                 convex_perimeters[ic] = hull.area
+                 convex_areas[ic] = hull.volume
+             except Exception:
+                 convex_perimeters[ic] = 0
+
+     convexity[mask_perimeters > 0.0] = (convex_perimeters[mask_perimeters > 0.0] /
+                                         mask_perimeters[mask_perimeters > 0.0])
+     solidity[convex_areas > 0.0] = (areas[convex_areas > 0.0] /
+                                     convex_areas[convex_areas > 0.0])
+     convexity = np.clip(convexity, 0.0, 1.0)
+     solidity = np.clip(solidity, 0.0, 1.0)
+     compactness = np.clip(compactness, 0.0, 1.0)
+     return convexity, solidity, compactness
+
+
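The hull-based ratios rely on a 2D quirk of `scipy.spatial.ConvexHull`: for planar points, `hull.area` is the hull *perimeter* and `hull.volume` is the enclosed *area*, so convexity = convex perimeter / mask perimeter and solidity = mask area / convex area. A quick demonstration:

```python
import numpy as np
from scipy.spatial import ConvexHull

pts = np.array([[0, 0], [0, 10], [10, 10], [10, 0], [5, 5]])
hull = ConvexHull(pts)  # the interior point (5, 5) is ignored
print(hull.area)    # 40.0 -> perimeter of the 10x10 square
print(hull.volume)  # 100.0 -> its enclosed area
```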
+ def get_masks_unet(output, cell_threshold=0, boundary_threshold=0):
+     """Create masks using cell probability and cell boundary.
+
+     Args:
+         output (ndarray): The output array containing cell probability and cell boundary.
+         cell_threshold (float, optional): The threshold value for cell probability. Defaults to 0.
+         boundary_threshold (float, optional): The threshold value for cell boundary. Defaults to 0.
+
+     Returns:
+         ndarray: The masks representing the segmented cells.
+
+     """
+     cells = (output[..., 1] - output[..., 0]) > cell_threshold
+     selem = generate_binary_structure(cells.ndim, connectivity=1)
+     labels, nlabels = label(cells, selem)
+
+     if output.shape[-1] > 2:
+         slices = find_objects(labels)
+         dists = 10000 * np.ones(labels.shape, np.float32)
+         mins = np.zeros(labels.shape, np.int32)
+         borders = np.logical_and(~(labels > 0), output[..., 2] > boundary_threshold)
+         pad = 10
+         for i, slc in enumerate(slices):
+             if slc is not None:
+                 slc_pad = tuple([
+                     slice(max(0, sli.start - pad), min(labels.shape[j], sli.stop + pad))
+                     for j, sli in enumerate(slc)
+                 ])
+                 msk = (labels[slc_pad] == (i + 1)).astype(np.float32)
+                 msk = 1 - gaussian_filter(msk, 5)
+                 dists[slc_pad] = np.minimum(dists[slc_pad], msk)
+                 mins[slc_pad][dists[slc_pad] == msk] = (i + 1)
+         labels[labels == 0] = borders[labels == 0] * mins[labels == 0]
+
+     masks = labels
+     shape0 = masks.shape
+     _, masks = np.unique(masks, return_inverse=True)
+     masks = np.reshape(masks, shape0)
+     return masks
+
+
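A hedged sketch of calling `get_masks_unet` on a synthetic two-channel output (channel 0 ≈ background score, channel 1 ≈ cell score); with only two channels the boundary-based splitting branch is skipped:

```python
import numpy as np

output = np.zeros((64, 64, 2), dtype=np.float32)
output[..., 0] = 1.0           # background score everywhere...
output[20:40, 20:40, 1] = 3.0  # ...except one bright cell region
masks = get_masks_unet(output, cell_threshold=0)
print(masks.max())  # 1 -> a single connected component, labeled 1
```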
+ def stitch3D(masks, stitch_threshold=0.25):
+     """
+     Stitch 2D masks into a 3D volume using a stitch_threshold on IOU.
+
+     Args:
+         masks (list or ndarray): List of 2D masks.
+         stitch_threshold (float, optional): Threshold value for stitching. Defaults to 0.25.
+
+     Returns:
+         list: List of stitched 3D masks.
+     """
+     mmax = masks[0].max()
+     empty = 0
+     for i in trange(len(masks) - 1):
+         iou = metrics._intersection_over_union(masks[i + 1], masks[i])[1:, 1:]
+         if not iou.size and empty == 0:
+             # nothing to stitch against yet; keep this plane's labels
+             mmax = masks[i + 1].max()
+         elif not iou.size and empty != 0:
+             icount = masks[i + 1].max()
+             istitch = np.arange(mmax + 1, mmax + icount + 1, 1, masks.dtype)
+             mmax += icount
+             istitch = np.append(np.array(0), istitch)
+             masks[i + 1] = istitch[masks[i + 1]]
+         else:
+             iou[iou < stitch_threshold] = 0.0
+             iou[iou < iou.max(axis=0)] = 0.0
+             istitch = iou.argmax(axis=1) + 1
+             ino = np.nonzero(iou.max(axis=1) == 0.0)[0]
+             istitch[ino] = np.arange(mmax + 1, mmax + len(ino) + 1, 1, masks.dtype)
+             mmax += len(ino)
+             istitch = np.append(np.array(0), istitch)
+             masks[i + 1] = istitch[masks[i + 1]]
+             empty = 1
+
+     return masks
+
+
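A hedged sketch of stitching two planes whose single cells overlap well above the threshold, assuming `masks` is an integer array of shape (nplanes, Ly, Lx):

```python
import numpy as np

masks = np.zeros((2, 32, 32), dtype=np.int32)
masks[0, 8:24, 8:24] = 1
masks[1, 9:25, 9:25] = 1   # IoU with plane 0 ~ 0.78 > 0.25
stitched = stitch3D(masks, stitch_threshold=0.25)
print(stitched[1].max())   # 1 -> the plane-1 cell keeps plane 0's label
```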
+ def diameters(masks):
+     """
+     Calculate the diameters of the objects in the given masks.
+
+     Parameters:
+         masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+
+     Returns:
+         tuple: A tuple containing the median diameter and an array of diameters for each object.
+
+     Examples:
+         >>> masks = np.array([[0, 1, 1], [1, 0, 0], [1, 1, 0]])
+         >>> diameters(masks)
+         (2.5231326..., array([2.23606798]))
+     """
+     uniq, counts = fastremap.unique(masks.astype("int32"), return_counts=True)
+     counts = counts[1:]
+     md = np.median(counts**0.5)
+     if np.isnan(md):
+         md = 0
+     md /= (np.pi**0.5) / 2
+     return md, counts**0.5
+
+
+ def radius_distribution(masks, bins):
+     """
+     Calculate the radius distribution of masks.
+
+     Args:
+         masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+         bins (int): Number of bins for the histogram.
+
+     Returns:
+         A tuple containing a normalized histogram of radii, median radius, array of radii.
+
+     """
+     unique, counts = np.unique(masks, return_counts=True)
+     counts = counts[unique != 0]
+     nb, _ = np.histogram((counts**0.5) * 0.5, bins)
+     nb = nb.astype(np.float32)
+     if nb.sum() > 0:
+         nb = nb / nb.sum()
+     md = np.median(counts**0.5) * 0.5
+     if np.isnan(md):
+         md = 0
+     md /= (np.pi**0.5) / 2
+     return nb, md, (counts**0.5) / 2
+
+
+ def size_distribution(masks):
+     """
+     Calculates the size distribution of masks.
+
+     Args:
+         masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
+
+     Returns:
+         float: The ratio of the 25th percentile of mask sizes to the 75th percentile of mask sizes.
+     """
+     counts = np.unique(masks, return_counts=True)[1][1:]
+     return np.percentile(counts, 25) / np.percentile(counts, 75)
+
+
+ def fill_holes_and_remove_small_masks(masks, min_size=15):
+     """ Fills holes in masks (2D/3D) and discards masks smaller than min_size.
+
+     This function fills holes in each mask using fill_voids.fill.
+     It also removes masks that are smaller than the specified min_size.
+
+     Parameters:
+         masks (ndarray): Int, 2D or 3D array of labelled masks.
+             0 represents no mask, while positive integers represent mask labels.
+             The size can be [Ly x Lx] or [Lz x Ly x Lx].
+         min_size (int, optional): Minimum number of pixels per mask.
+             Masks smaller than min_size will be removed.
+             Set to -1 to turn off this functionality. Default is 15.
+
+     Returns:
+         ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed.
+             0 represents no mask, while positive integers represent mask labels.
+             The size is [Ly x Lx] or [Lz x Ly x Lx].
+     """
+
+     if masks.ndim > 3 or masks.ndim < 2:
+         raise ValueError("fill_holes_and_remove_small_masks takes 2D or 3D array, "
+                          "not %dD array" % masks.ndim)
+
+     # Filter small masks
+     if min_size > 0:
+         counts = fastremap.unique(masks, return_counts=True)[1][1:]
+         masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
+         fastremap.renumber(masks, in_place=True)
+
+     slices = find_objects(masks)
+     j = 0
+     for i, slc in enumerate(slices):
+         if slc is not None:
+             msk = masks[slc] == (i + 1)
+             msk = fill_voids.fill(msk)
+             masks[slc][msk] = (j + 1)
+             j += 1
+
+     if min_size > 0:
+         counts = fastremap.unique(masks, return_counts=True)[1][1:]
+         masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
+         fastremap.renumber(masks, in_place=True)
+
+     return masks
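A minimal sketch of the cleanup behavior: interior holes are filled and masks below `min_size` are dropped, with labels renumbered consecutively:

```python
import numpy as np

masks = np.zeros((32, 32), dtype=np.int32)
masks[5:20, 5:20] = 1
masks[10:12, 10:12] = 0  # hole inside mask 1 -> filled
masks[25, 25] = 2        # 1-pixel mask -> removed at min_size=15
clean = fill_holes_and_remove_small_masks(masks, min_size=15)
print(clean.max(), (clean == 1).sum())  # 1 225
```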
models/seg_post_model/cellpose/version.py ADDED
@@ -0,0 +1,18 @@
+ """
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
+ """
+ from importlib.metadata import PackageNotFoundError, version
+ import sys
+ from platform import python_version
+ import torch
+
+ try:
+     version = version("cellpose")
+ except PackageNotFoundError:
+     version = "unknown"
+
+ version_str = f"""
+ cellpose version: \t{version}
+ platform: \t{sys.platform}
+ python version: \t{python_version()}
+ torch version: \t{torch.__version__}"""
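Assuming the package resolves as `cellpose` on the import path (in this repo it lives under models/seg_post_model/cellpose), the module exposes both the bare version string and a formatted summary:

```python
from cellpose.version import version, version_str

print(version)      # "unknown" if cellpose isn't installed as a package
print(version_str)  # cellpose / platform / python / torch versions
```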
models/seg_post_model/cellpose/vit_sam.py ADDED
@@ -0,0 +1,195 @@
+ """
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+ """
+
+ import torch
+ from segment_anything import sam_model_registry
+ torch.backends.cuda.matmul.allow_tf32 = True
+ from torch import nn
+ import torch.nn.functional as F
+
+ class Transformer(nn.Module):
+     def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4,
+                  checkpoint=None, dtype=torch.float32):
+         super(Transformer, self).__init__()
+         """
+         print(self.encoder.patch_embed)
+         PatchEmbed(
+           (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
+         )
+         print(self.encoder.neck)
+         Sequential(
+           (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
+           (1): LayerNorm2d()
+           (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+           (3): LayerNorm2d()
+         )
+         """
+         # instantiate the vit model, default to not loading SAM
+         # checkpoint = sam_vit_l_0b3195.pth is standard pretrained SAM
+         self.encoder = sam_model_registry[backbone](checkpoint).image_encoder
+         w = self.encoder.patch_embed.proj.weight.detach()
+         nchan = w.shape[0]
+
+         # change token size to ps x ps
+         self.ps = ps
+         self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
+         self.encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]
+
+         # adjust position embeddings for new bsize and new token size
+         ds = (1024 // 16) // (bsize // ps)
+         self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:, ::ds, ::ds],
+                                               requires_grad=True)
+
+         # readout weights for nout output channels
+         # if nout is changed, weights will not load correctly from pretrained Cellpose-SAM
+         self.nout = nout
+         self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1)
+
+         # W2 reshapes token space to pixel space, not trainable
+         self.W2 = nn.Parameter(
+             torch.eye(self.nout * ps**2).reshape(self.nout * ps**2, self.nout, ps, ps),
+             requires_grad=False)
+
+         # fraction of layers to drop at random during training
+         self.rdrop = rdrop
+
+         # average diameter of ROIs from training images from fine-tuning
+         self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
+         # average diameter of ROIs during main training
+         self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)
+
+         # set attention to global in every layer
+         for blk in self.encoder.blocks:
+             blk.window_size = 0
+
+         self.dtype = dtype
+
+     def forward(self, x, feat=None):
+         # same progression as SAM until readout
+         x = self.encoder.patch_embed(x)
+         if feat is not None:
+             feat = self.encoder.patch_embed(feat)
+             x = x + x * feat * 0.5
+
+         if self.encoder.pos_embed is not None:
+             x = x + self.encoder.pos_embed
+
+         if self.training and self.rdrop > 0:
+             nlay = len(self.encoder.blocks)
+             rdrop = (torch.rand((len(x), nlay), device=x.device) <
+                      torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
+             for i, blk in enumerate(self.encoder.blocks):
+                 mask = rdrop[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+                 x = x * mask + blk(x) * (1 - mask)
+         else:
+             for blk in self.encoder.blocks:
+                 x = blk(x)
+
+         x = self.encoder.neck(x.permute(0, 3, 1, 2))
+
+         # readout is changed here
+         x1 = self.out(x)
+         x1 = F.conv_transpose2d(x1, self.W2, stride=self.ps, padding=0)
+
+         # maintain the second output of feature size 256 for backwards compatibility
+         return x1, torch.randn((x.shape[0], 256), device=x.device)
+
+     def load_model(self, PATH, device, strict=False):
+         state_dict = torch.load(PATH, map_location=device, weights_only=True)
+         keys = [k for k in state_dict.keys()]
+         if keys[0][:7] == "module.":
+             from collections import OrderedDict
+             new_state_dict = OrderedDict()
+             for k, v in state_dict.items():
+                 name = k[7:]  # remove 'module.' of DataParallel/DistributedDataParallel
+                 new_state_dict[name] = v
+             self.load_state_dict(new_state_dict, strict=strict)
+         else:
+             self.load_state_dict(state_dict, strict=strict)
+
+         if self.dtype != torch.float32:
+             self = self.to(self.dtype)
+
+     @property
+     def device(self):
+         """
+         Get the device of the model.
+
+         Returns:
+             torch.device: The device of the model.
+         """
+         return next(self.parameters()).device
+
+     def save_model(self, filename):
+         """
+         Save the model to a file.
+
+         Args:
+             filename (str): The path to the file where the model will be saved.
+         """
+         torch.save(self.state_dict(), filename)
+
+
+ class CPnetBioImageIO(Transformer):
+     """
+     A subclass of the CP-SAM model compatible with the BioImage.IO Spec.
+
+     This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
+     allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
+     """
+
+     def forward(self, x):
+         """
+         Perform a forward pass of the CP-SAM model and return unpacked tensors.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             tuple: A tuple containing the output tensor and style tensor.
+         """
+         # Transformer.forward returns (output, style); unlike CPnet there are
+         # no downsampled tensors to unpack
+         output_tensor, style_tensor = super().forward(x)
+         return output_tensor, style_tensor
+
+     def load_model(self, filename, device=None):
+         """
+         Load the model from a file.
+
+         Args:
+             filename (str): The path to the file where the model is saved.
+             device (torch.device, optional): The device to load the model on. Defaults to None.
+         """
+         if (device is not None) and (device.type != "cpu"):
+             state_dict = torch.load(filename, map_location=device, weights_only=True)
+         else:
+             self.__init__(nout=self.nout)
+             state_dict = torch.load(filename, map_location=torch.device("cpu"),
+                                     weights_only=True)
+
+         self.load_state_dict(state_dict)
+
+     def load_state_dict(self, state_dict):
+         """
+         Load the state dictionary into the model.
+
+         This method overrides the default `load_state_dict` to handle Cellpose's custom
+         loading mechanism and ensures compatibility with BioImage.IO Core.
+
+         Args:
+             state_dict (Mapping[str, Any]): A state dictionary to load into the model
+         """
+         # legacy CPnet checkpoints store the readout under "output.*"; guard the
+         # lookup so Transformer-style checkpoints do not raise a KeyError
+         if ("output.2.weight" in state_dict and
+                 state_dict["output.2.weight"].shape[0] != self.nout):
+             for name in self.state_dict():
+                 if "output" not in name:
+                     self.state_dict()[name].copy_(state_dict[name])
+         else:
+             super().load_state_dict(
+                 {name: param for name, param in state_dict.items()},
+                 strict=False)
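With the defaults (bsize=256, ps=8) the patch embedding produces a 32×32 token grid, the position embeddings are subsampled by ds = (1024//16)//(256//8) = 2 to match, and the `W2` transposed convolution (stride = ps) maps the readout back to full resolution. A hedged smoke test, assuming `segment_anything` is installed (`checkpoint=None` builds the ViT-L encoder without loading SAM weights):

```python
import torch

model = Transformer(backbone="vit_l", ps=8, nout=3, bsize=256, checkpoint=None)
model.eval()
x = torch.zeros(1, 3, 256, 256)
with torch.no_grad():
    y, style = model(x)
print(y.shape, style.shape)  # torch.Size([1, 3, 256, 256]) torch.Size([1, 256])
```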
models/seg_post_model/cellpose/vit_sam_new.py ADDED
@@ -0,0 +1,197 @@
+ """
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+ """
+
+ import torch
+ from segment_anything import sam_model_registry
+ torch.backends.cuda.matmul.allow_tf32 = True
+ from torch import nn
+ import torch.nn.functional as F
+
+ class Transformer(nn.Module):
+     def __init__(self, backbone="vit_l", ps=16, nout=3, bsize=256, rdrop=0.4,
+                  checkpoint=None, dtype=torch.float32):
+         super(Transformer, self).__init__()
+         """
+         print(self.encoder.patch_embed)
+         PatchEmbed(
+           (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
+         )
+         print(self.encoder.neck)
+         Sequential(
+           (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
+           (1): LayerNorm2d()
+           (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+           (3): LayerNorm2d()
+         )
+         """
+         # instantiate the vit model; unlike vit_sam.py, this variant loads the
+         # standard pretrained SAM checkpoint (sam_vit_l_0b3195.pth) by default
+         if checkpoint is None:
+             checkpoint = "sam_vit_l_0b3195.pth"
+         self.encoder = sam_model_registry[backbone](checkpoint).image_encoder
+         w = self.encoder.patch_embed.proj.weight.detach()
+         nchan = w.shape[0]
+
+         # keep SAM's original 16x16 patch embedding (token size ps x ps)
+         self.ps = ps
+         # self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
+         # self.encoder.patch_embed.proj.weight.data = w[:,:,::16//ps,::16//ps]
+
+         # adjust position embeddings for new bsize and new token size
+         ds = (1024 // 16) // (bsize // ps)
+         self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:, ::ds, ::ds],
+                                               requires_grad=True)
+
+         # readout weights for nout output channels
+         # if nout is changed, weights will not load correctly from pretrained Cellpose-SAM
+         self.nout = nout
+         self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1)
+
+         # W2 reshapes token space to pixel space, not trainable
+         self.W2 = nn.Parameter(
+             torch.eye(self.nout * ps**2).reshape(self.nout * ps**2, self.nout, ps, ps),
+             requires_grad=False)
+
+         # fraction of layers to drop at random during training
+         self.rdrop = rdrop
+
+         # average diameter of ROIs from training images from fine-tuning
+         self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
+         # average diameter of ROIs during main training
+         self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)
+
+         # set attention to global in every layer
+         for blk in self.encoder.blocks:
+             blk.window_size = 0
+
+         self.dtype = dtype
+
+     def forward(self, x, feat=None):
+         # same progression as SAM until readout
+         x = self.encoder.patch_embed(x)
+         if feat is not None:
+             feat = self.encoder.patch_embed(feat)
+             x = x + x * feat * 0.5
+
+         if self.encoder.pos_embed is not None:
+             x = x + self.encoder.pos_embed
+
+         if self.training and self.rdrop > 0:
+             nlay = len(self.encoder.blocks)
+             rdrop = (torch.rand((len(x), nlay), device=x.device) <
+                      torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
+             for i, blk in enumerate(self.encoder.blocks):
+                 mask = rdrop[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+                 x = x * mask + blk(x) * (1 - mask)
+         else:
+             for blk in self.encoder.blocks:
+                 x = blk(x)
+
+         x = self.encoder.neck(x.permute(0, 3, 1, 2))
+
+         # readout is changed here
+         x1 = self.out(x)
+         x1 = F.conv_transpose2d(x1, self.W2, stride=self.ps, padding=0)
+
+         # maintain the second output of feature size 256 for backwards compatibility
+         return x1, torch.randn((x.shape[0], 256), device=x.device)
+
+     def load_model(self, PATH, device, strict=False):
+         state_dict = torch.load(PATH, map_location=device, weights_only=True)
+         keys = [k for k in state_dict.keys()]
+         if keys[0][:7] == "module.":
+             from collections import OrderedDict
+             new_state_dict = OrderedDict()
+             for k, v in state_dict.items():
+                 name = k[7:]  # remove 'module.' of DataParallel/DistributedDataParallel
+                 new_state_dict[name] = v
+             self.load_state_dict(new_state_dict, strict=strict)
+         else:
+             self.load_state_dict(state_dict, strict=strict)
+
+         if self.dtype != torch.float32:
+             self = self.to(self.dtype)
+
+     @property
+     def device(self):
+         """
+         Get the device of the model.
+
+         Returns:
+             torch.device: The device of the model.
+         """
+         return next(self.parameters()).device
+
+     def save_model(self, filename):
+         """
+         Save the model to a file.
+
+         Args:
+             filename (str): The path to the file where the model will be saved.
+         """
+         torch.save(self.state_dict(), filename)
+
+
+ class CPnetBioImageIO(Transformer):
+     """
+     A subclass of the CP-SAM model compatible with the BioImage.IO Spec.
+
+     This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
+     allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
+     """
+
+     def forward(self, x):
+         """
+         Perform a forward pass of the CP-SAM model and return unpacked tensors.
+
+         Args:
+             x (torch.Tensor): Input tensor.
+
+         Returns:
+             tuple: A tuple containing the output tensor and style tensor.
+         """
+         # Transformer.forward returns (output, style); unlike CPnet there are
+         # no downsampled tensors to unpack
+         output_tensor, style_tensor = super().forward(x)
+         return output_tensor, style_tensor
+
+     def load_model(self, filename, device=None):
+         """
+         Load the model from a file.
+
+         Args:
+             filename (str): The path to the file where the model is saved.
+             device (torch.device, optional): The device to load the model on. Defaults to None.
+         """
+         if (device is not None) and (device.type != "cpu"):
+             state_dict = torch.load(filename, map_location=device, weights_only=True)
+         else:
+             self.__init__(nout=self.nout)
+             state_dict = torch.load(filename, map_location=torch.device("cpu"),
+                                     weights_only=True)
+
+         self.load_state_dict(state_dict)
+
+     def load_state_dict(self, state_dict):
+         """
+         Load the state dictionary into the model.
+
+         This method overrides the default `load_state_dict` to handle Cellpose's custom
+         loading mechanism and ensures compatibility with BioImage.IO Core.
+
+         Args:
+             state_dict (Mapping[str, Any]): A state dictionary to load into the model
+         """
+         # legacy CPnet checkpoints store the readout under "output.*"; guard the
+         # lookup so Transformer-style checkpoints do not raise a KeyError
+         if ("output.2.weight" in state_dict and
+                 state_dict["output.2.weight"].shape[0] != self.nout):
+             for name in self.state_dict():
+                 if "output" not in name:
+                     self.state_dict()[name].copy_(state_dict[name])
+         else:
+             super().load_state_dict(
+                 {name: param for name, param in state_dict.items()},
+                 strict=False)
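The functional differences from vit_sam.py are the default token size (ps=16, which keeps SAM's original 16×16 patch embedding, hence the commented-out re-projection lines) and that the standard SAM checkpoint is loaded by default. A small sketch of the resulting token-grid arithmetic at bsize=256:

```python
# assuming the 1024/16 SAM position-embedding grid used in both variants
for ps in (8, 16):
    ds = (1024 // 16) // (256 // ps)
    print(f"ps={ps}: {256 // ps}x{256 // ps} tokens, pos_embed subsampled by {ds}")
# ps=8:  32x32 tokens, ds=2   (vit_sam.py)
# ps=16: 16x16 tokens, ds=4   (vit_sam_new.py)
```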