phoebehxf commited on
Commit
aff3c6f
·
1 Parent(s): 01050f6
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
Files changed (50) hide show
  1. README.md +5 -4
  2. _utils/attn_utils.py +592 -0
  3. _utils/attn_utils_new.py +610 -0
  4. _utils/config.yaml +15 -0
  5. _utils/example_config.yaml +20 -0
  6. _utils/load_models.py +16 -0
  7. _utils/load_track_data.py +118 -0
  8. _utils/misc_helper.py +37 -0
  9. _utils/seg_eval.py +61 -0
  10. _utils/track_args.py +157 -0
  11. app.py +1638 -0
  12. config.py +44 -0
  13. counting.py +340 -0
  14. example_imgs/cnt/047cell.png +3 -0
  15. example_imgs/cnt/62_10.png +3 -0
  16. example_imgs/cnt/6800-17000_GTEX-XQ3S_Adipose-Subcutaneous.png +3 -0
  17. example_imgs/seg/003_img.png +3 -0
  18. example_imgs/seg/1-23 [Scan I08].png +3 -0
  19. example_imgs/seg/10X_B2_Tile-15.aligned.png +3 -0
  20. example_imgs/seg/1977_Well_F-5_Field_1.png +3 -0
  21. example_imgs/seg/200972823[5179]_RhoGGG_YAP_TAZ [200972823 Well K6 Field #2].png +3 -0
  22. example_imgs/seg/A172_Phase_C7_1_00d00h00m_1.png +3 -0
  23. example_imgs/seg/JE2NileRed_oilp22_PMP_101220_011_NR.png +3 -0
  24. example_imgs/seg/OpenTest_031.png +3 -0
  25. example_imgs/seg/X_24.png +3 -0
  26. example_imgs/seg/exp_A01_G002_0001.oir.png +3 -0
  27. example_imgs/tra/tracking_test_sequence.zip +3 -0
  28. example_imgs/tra/tracking_test_sequence2.zip +3 -0
  29. inference_count.py +237 -0
  30. inference_seg.py +87 -0
  31. inference_track.py +202 -0
  32. models/.DS_Store +0 -0
  33. models/enc_model/__init__.py +0 -0
  34. models/enc_model/backbone.py +64 -0
  35. models/enc_model/loca.py +232 -0
  36. models/enc_model/loca_args.py +44 -0
  37. models/enc_model/mlp.py +23 -0
  38. models/enc_model/ope.py +245 -0
  39. models/enc_model/positional_encoding.py +30 -0
  40. models/enc_model/regression_head.py +92 -0
  41. models/enc_model/transformer.py +94 -0
  42. models/enc_model/unet_parts.py +77 -0
  43. models/model.py +653 -0
  44. models/seg_post_model/cellpose/__init__.py +1 -0
  45. models/seg_post_model/cellpose/__main__.py +272 -0
  46. models/seg_post_model/cellpose/cli.py +240 -0
  47. models/seg_post_model/cellpose/core.py +322 -0
  48. models/seg_post_model/cellpose/denoise.py +1474 -0
  49. models/seg_post_model/cellpose/dynamics.py +691 -0
  50. models/seg_post_model/cellpose/export.py +405 -0
README.md CHANGED
@@ -1,11 +1,12 @@
1
  ---
2
  title: MicroscopyMatching
3
- emoji: 🐢
4
- colorFrom: green
5
- colorTo: pink
6
  sdk: gradio
7
- sdk_version: 6.3.0
8
  app_file: app.py
 
9
  pinned: false
10
  ---
11
 
 
1
  ---
2
  title: MicroscopyMatching
3
+ emoji: 🚀
4
+ colorFrom: gray
5
+ colorTo: red
6
  sdk: gradio
7
+ sdk_version: 5.49.1
8
  app_file: app.py
9
+ python_version: 3.11
10
  pinned: false
11
  ---
12
 
_utils/attn_utils.py ADDED
@@ -0,0 +1,592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from IPython.display import display
7
+ from PIL import Image
8
+ from typing import Union, Tuple, List
9
+ from einops import rearrange, repeat
10
+ import math
11
+ from torch import nn, einsum
12
+ from inspect import isfunction
13
+ from diffusers.utils import logging
14
+ try:
15
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
16
+ except:
17
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
18
+
19
+ try:
20
+ from diffusers.models.cross_attention import CrossAttention
21
+ except:
22
+ from diffusers.models.attention_processor import Attention as CrossAttention
23
+
24
+ MAX_NUM_WORDS = 77
25
+ LOW_RESOURCE = False
26
+
27
class CountingCrossAttnProcessor1:
    """Attention processor that computes standard (cross-)attention while
    letting an external controller observe the attention probabilities.

    The controller is invoked as ``attnstore(attn, is_cross, place_in_unet)``
    on every call, which is how the surrounding code collects attention maps
    for counting/visualization.
    """

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        # Controller (e.g. an AttentionStore) that receives each attention map.
        self.attnstore = attnstore
        # Location tag: "down", "mid", or "up".
        self.place_in_unet = place_in_unet

    def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, dim = hidden_states.shape
        num_heads = attn_layer.heads

        query = attn_layer.to_q(hidden_states)
        # Cross-attention iff an encoder context is supplied; otherwise self-attention.
        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states
        key = attn_layer.to_k(context)
        value = attn_layer.to_v(context)

        query = self.head_to_batch_dim(query, num_heads)
        key = self.head_to_batch_dim(key, num_heads)
        value = self.head_to_batch_dim(value, num_heads)

        scores = torch.einsum("b i d, b j d -> b i j", query, key) * attn_layer.scale

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(batch_size, -1)
            neg_inf = -torch.finfo(scores.dtype).max
            attention_mask = attention_mask[:, None, :].repeat(num_heads, 1, 1)
            scores.masked_fill_(~attention_mask, neg_inf)

        probs = scores.softmax(dim=-1).clone()
        # Hand the attention map to the controller (may record or edit it).
        self.attnstore(probs, is_cross, self.place_in_unet)

        out = torch.einsum("b i j, b j d -> b i d", probs, value)
        out = self.batch_to_head_dim(out, num_heads)

        # diffusers wraps the output projection in a ModuleList [linear, dropout];
        # older versions expose the linear layer directly.
        if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList:
            projection = attn_layer.to_out[0]
        else:
            projection = attn_layer.to_out

        return projection(out)

    def batch_to_head_dim(self, tensor, head_size):
        """Reshape (batch*heads, seq, dim) back to (batch, seq, heads*dim)."""
        merged_batch, seq_len, dim = tensor.shape
        batch = merged_batch // head_size
        tensor = tensor.reshape(batch, head_size, seq_len, dim)
        return tensor.permute(0, 2, 1, 3).reshape(batch, seq_len, dim * head_size)

    def head_to_batch_dim(self, tensor, head_size, out_dim=3):
        """Reshape (batch, seq, heads*dim) to (batch*heads, seq, dim) when out_dim == 3."""
        batch, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch, seq_len, head_size, dim // head_size).permute(0, 2, 1, 3)
        if out_dim == 3:
            tensor = tensor.reshape(batch * head_size, seq_len, dim // head_size)
        return tensor
94
+
95
+
96
def register_attention_control(model, controller):
    """Replace every attention processor in ``model.unet`` with a
    ``CountingCrossAttnProcessor1`` wired to ``controller``, then record the
    number of hooked layers on ``controller.num_att_layers``."""
    processors = {}
    hooked = 0
    for name in model.unet.attn_processors.keys():
        # Kept for parity with the upstream recipe; the value itself is unused.
        cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = model.unet.config.block_out_channels[-1]
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = model.unet.config.block_out_channels[block_id]
            place_in_unet = "down"
        else:
            # Not an attention location we track.
            continue

        hooked += 1
        processors[name] = CountingCrossAttnProcessor1(
            attnstore=controller, place_in_unet=place_in_unet
        )

    model.unet.set_attn_processor(processors)
    controller.num_att_layers = hooked
126
+
127
def register_hier_output(model):
    """Monkey-patch ``model.unet.forward`` to additionally return the hidden
    states produced by every up-block (hierarchical features).

    The replacement body mirrors diffusers' ``UNet2DConditionModel.forward``;
    the functional differences are the ``out_list`` collection after each
    up-block and the extended return value ``(UNet2DConditionOutput, out_list)``
    when ``return_dict`` is True.

    Fix: removed the unused ``from ldm.modules.diffusionmodules.util import
    checkpoint, timestep_embedding`` — neither name was referenced, and the
    import crashed at call time whenever the ``ldm`` package was absent.
    (The newer copy in ``_utils/attn_utils_new.py`` already omits it.)

    NOTE: the patched forward is a plain closure (no ``self`` parameter);
    callers keep using ``model.unet(...)`` / ``model.unet.forward(...)`` as before.
    """
    self = model.unet
    logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

    def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None,
                attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None,
                down_block_additional_residuals=None, mid_block_additional_residual=None,
                encoder_attention_mask=None, return_dict=True):

        # Collected up-block hidden states (the "hierarchical output").
        out_list = []

        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)
                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kadinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)  # 1, 320, 64, 64

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)

        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )

            # Record every up-block's output (this is the patched-in behavior).
            # if i in [1, 4, 7]:
            out_list.append(sample)

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample), out_list

    self.forward = forward
397
+
398
+
399
class AttentionControl(abc.ABC):
    """Abstract base for attention controllers.

    Subclasses implement ``forward`` to observe/edit one attention map.
    ``__call__`` drives the bookkeeping: it counts attention layers, advances
    ``cur_step`` once every registered layer has been seen, and triggers
    ``between_steps`` at that boundary.
    """

    def __init__(self):
        # Diffusion step counter (advances once per full pass over all layers).
        self.cur_step = 0
        # Number of hooked attention layers; set by register_attention_control.
        self.num_att_layers = -1
        # Index of the attention layer currently being processed.
        self.cur_att_layer = 0

    def step_callback(self, x_t):
        """Hook invoked on the latent after each step; identity by default."""
        return x_t

    def between_steps(self):
        """Hook invoked when a full pass over all attention layers completes."""
        return

    @property
    def num_uncond_att_layers(self):
        # Layers belonging to the unconditional branch are skipped by default.
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # Batched cond/uncond: only the conditional half is forwarded.
                half = attn.shape[0] // 2
                attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        """Rewind the step/layer counters to their initial state."""
        self.cur_step = 0
        self.cur_att_layer = 0
438
+
439
+
440
class EmptyControl(AttentionControl):
    """No-op controller: passes every attention map through untouched."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Identity — nothing is recorded or modified.
        return attn
444
+
445
+
446
class AttentionStore(AttentionControl):
    """Controller that records attention maps, bucketed by UNet location
    ("down"/"mid"/"up") and attention kind ("cross"/"self").

    Per-step maps accumulate in ``step_store``; at each step boundary they are
    promoted to ``attention_store`` and, optionally, summed into
    ``global_store`` across all steps.
    """

    @staticmethod
    def get_empty_store():
        """Fresh dict with one empty list per (location, kind) bucket."""
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def __init__(self, max_size=32, save_global_store=False):
        '''
        Initialize an empty AttentionStore
        :param step_index: used to visualize only a specific step in the diffusion process
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        # Maps with more than max_size**2 query positions are not stored.
        self.max_size = max_size
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        self.curr_step_index = 0

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        bucket = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # avoid memory overhead: skip maps from high-resolution layers
        if attn.shape[1] <= self.max_size ** 2:
            self.step_store[bucket].append(attn)
        return attn

    def between_steps(self):
        # Promote this step's maps, accumulate the running global sum if asked,
        # then start a fresh per-step store.
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if len(self.global_store) == 0:
                    self.global_store = self.step_store
                else:
                    for bucket in self.global_store:
                        for i in range(len(self.global_store[bucket])):
                            self.global_store[bucket][i] += self.step_store[bucket][i].detach()
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the most recent step's attention maps."""
        return self.attention_store

    def get_average_global_attention(self):
        """Return per-bucket maps averaged over all completed steps."""
        return {bucket: [item / self.cur_step for item in self.global_store[bucket]]
                for bucket in self.attention_store}

    def reset(self):
        """Clear counters and every stored map."""
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
499
+
500
def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    """Average stored attention maps of resolution ``res`` over heads/layers.

    Collects every map from the requested UNet locations whose spatial size is
    ``res * res``, picks the ``select``-th prompt's slice, and returns the mean
    over all collected maps.
    """
    stored = attention_store.get_average_attention()
    target_pixels = res ** 2
    kind = "cross" if is_cross else "self"
    collected = []
    for location in from_where:
        for item in stored[f"{location}_{kind}"]:
            if item.shape[1] == target_pixels:
                per_prompt = item.reshape(len(prompts), -1, res, res, item.shape[-1])
                collected.append(per_prompt[select])
    stacked = torch.cat(collected, dim=0)
    # Mean over the head/layer axis.
    return stacked.sum(0) / stacked.shape[0]
512
+
513
+
514
def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    """Visualize the cross-attention map of every token in ``prompts[select]``.

    Each token's (res x res) map is normalized to [0, 255], upscaled to
    256 x 256, captioned with the decoded token, and shown as one tiled image.

    Fix: ``aggregate_attention`` takes ``prompts`` as its first argument; the
    original call omitted it, shifting every argument one position and making
    this function unusable.
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    # BUGFIX: pass `prompts` through to aggregate_attention.
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        # Scale to 0..255 grayscale, replicate to 3 channels.
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))
528
+
529
+
530
def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0, prompts=None):
    """Visualize the top ``max_com`` SVD components of the self-attention map.

    Fix: ``aggregate_attention`` takes ``prompts`` as its first argument; the
    original call omitted it (and had no ``prompts`` in scope at all), so every
    argument bound one position off and the function always crashed. A
    backward-compatible ``prompts=None`` keyword is added; the default assumes
    a single prompt in the batch — confirm against callers.
    """
    if prompts is None:
        prompts = [""]  # single-prompt batch by default
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
    # SVD of the row-centered attention matrix; rows of vh are components.
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        # Normalize each component to 0..255 grayscale, upscale for display.
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))
544
+
545
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a copy of ``image`` with ``text`` centered on a white strip
    appended below it (strip height = 20% of the image height)."""
    h, w, c = image.shape
    strip = int(h * .2)
    canvas = np.ones((h + strip, w, c), dtype=np.uint8) * 255
    canvas[:h] = image
    font = cv2.FONT_HERSHEY_SIMPLEX
    # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
    text_w, text_h = cv2.getTextSize(text, font, 1, 2)[0]
    x = (w - text_w) // 2
    y = h + strip - text_h // 2
    cv2.putText(canvas, text, (x, y), font, 1, text_color, 2)
    return canvas
556
+
557
+
558
def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tile ``images`` (list of HxWx3 arrays, or a 4-D array, or one image)
    into a grid with ``num_rows`` rows and display it inline (IPython)."""
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        # A single image becomes a one-element grid.
        images = [images]
        num_empty = 0

    # NOTE(review): padding count mirrors the upstream prompt-to-prompt code;
    # with num_rows=1 it is always 0.
    blank = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [img.astype(np.uint8) for img in images] + [blank] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    grid = np.ones((h * num_rows + offset * (num_rows - 1),
                    w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for row in range(num_rows):
        for col in range(num_cols):
            y0 = row * (h + offset)
            x0 = col * (w + offset)
            grid[y0:y0 + h, x0:x0 + w] = images[row * num_cols + col]

    display(Image.fromarray(grid))
583
+
584
def self_cross_attn(self_attn, cross_attn):
    """Weight per-pixel self-attention maps by the cross-attention map.

    ``self_attn`` has shape [res, res, res*res] and ``cross_attn`` [res, res];
    the cross map is flattened and broadcast over the last axis, then averaged,
    yielding a [1, 1, res, res] saliency map.
    """
    res = self_attn.shape[0]
    assert res == cross_attn.shape[0]
    # Flatten the cross map so it lines up with self_attn's last axis.
    weights = cross_attn.reshape([res * res])
    weighted = weights * self_attn
    return weighted.mean(-1).unsqueeze(0).unsqueeze(0)
_utils/attn_utils_new.py ADDED
@@ -0,0 +1,610 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import abc
2
+
3
+ import cv2
4
+ import numpy as np
5
+ import torch
6
+ from IPython.display import display
7
+ from PIL import Image
8
+ from typing import Union, Tuple, List
9
+ from einops import rearrange, repeat
10
+ import math
11
+ from torch import nn, einsum
12
+ from inspect import isfunction
13
+ from diffusers.utils import logging
14
+ try:
15
+ from diffusers.models.unet_2d_condition import UNet2DConditionOutput
16
+ except:
17
+ from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
18
+ try:
19
+ from diffusers.models.cross_attention import CrossAttention
20
+ except:
21
+ from diffusers.models.attention_processor import Attention as CrossAttention
22
+ from typing import Any, Dict, List, Optional, Tuple, Union
23
+ MAX_NUM_WORDS = 77
24
+ LOW_RESOURCE = False
25
+
26
class CountingCrossAttnProcessor1:
    """Attention processor that computes standard (cross-)attention while
    letting an external controller observe the attention probabilities.

    The controller is invoked as ``attnstore(attn, is_cross, place_in_unet)``
    on every call, which is how the surrounding code collects attention maps
    for counting/visualization.
    """

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        # Controller (e.g. an AttentionStore) that receives each attention map.
        self.attnstore = attnstore
        # Location tag: "down", "mid", or "up".
        self.place_in_unet = place_in_unet

    def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, dim = hidden_states.shape
        num_heads = attn_layer.heads

        query = attn_layer.to_q(hidden_states)
        # Cross-attention iff an encoder context is supplied; otherwise self-attention.
        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states
        key = attn_layer.to_k(context)
        value = attn_layer.to_v(context)

        query = self.head_to_batch_dim(query, num_heads)
        key = self.head_to_batch_dim(key, num_heads)
        value = self.head_to_batch_dim(value, num_heads)

        scores = torch.einsum("b i d, b j d -> b i j", query, key) * attn_layer.scale

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(batch_size, -1)
            neg_inf = -torch.finfo(scores.dtype).max
            attention_mask = attention_mask[:, None, :].repeat(num_heads, 1, 1)
            scores.masked_fill_(~attention_mask, neg_inf)

        probs = scores.softmax(dim=-1).clone()
        # Hand the attention map to the controller (may record or edit it).
        self.attnstore(probs, is_cross, self.place_in_unet)

        out = torch.einsum("b i j, b j d -> b i d", probs, value)
        out = self.batch_to_head_dim(out, num_heads)

        # diffusers wraps the output projection in a ModuleList [linear, dropout];
        # older versions expose the linear layer directly.
        if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList:
            projection = attn_layer.to_out[0]
        else:
            projection = attn_layer.to_out

        return projection(out)

    def batch_to_head_dim(self, tensor, head_size):
        """Reshape (batch*heads, seq, dim) back to (batch, seq, heads*dim)."""
        merged_batch, seq_len, dim = tensor.shape
        batch = merged_batch // head_size
        tensor = tensor.reshape(batch, head_size, seq_len, dim)
        return tensor.permute(0, 2, 1, 3).reshape(batch, seq_len, dim * head_size)

    def head_to_batch_dim(self, tensor, head_size, out_dim=3):
        """Reshape (batch, seq, heads*dim) to (batch*heads, seq, dim) when out_dim == 3."""
        batch, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch, seq_len, head_size, dim // head_size).permute(0, 2, 1, 3)
        if out_dim == 3:
            tensor = tensor.reshape(batch * head_size, seq_len, dim // head_size)
        return tensor
93
+
94
+
95
def register_attention_control(model, controller):
    """Replace every attention processor in `model.unet` with a
    CountingCrossAttnProcessor1 wired to `controller`, and record on the
    controller how many attention layers will report to it.

    Args:
        model: pipeline-like object exposing a diffusers `unet`.
        controller: AttentionControl instance; its `num_att_layers`
            attribute is set here.
    """
    attn_procs = {}
    cross_att_count = 0
    for name in model.unet.attn_processors.keys():
        # Map the processor name to its place in the UNet; skip anything
        # that is not a down/mid/up attention module.
        if name.startswith("mid_block"):
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            place_in_unet = "down"
        else:
            continue

        cross_att_count += 1
        attn_procs[name] = CountingCrossAttnProcessor1(
            attnstore=controller, place_in_unet=place_in_unet
        )

    model.unet.set_attn_processor(attn_procs)
    controller.num_att_layers = cross_att_count
125
+
126
def register_hier_output(model):
    """Monkey-patch `model.unet.forward` so that, besides the usual
    UNet2DConditionOutput, it also returns `out_list`: the hidden states
    produced by every up-block (a hierarchy of decoder features).

    The replacement `forward` is a near-verbatim copy of diffusers'
    UNet2DConditionModel.forward; the functional additions are the
    `out_list` collection in the up path and the extended return value.
    """
    self = model.unet
    # NOTE(review): `logging` here must be diffusers.utils.logging (it has
    # `get_logger`), not the stdlib module — confirm the file's imports.
    logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

    def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None,
                attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None, down_block_additional_residuals=None,
                mid_block_additional_residual=None, encoder_attention_mask=None, return_dict=True):

        out_list = []  # decoder (up-block) hidden states, appended per up block

        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep,      0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0,     discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            # Same keep/discard -> additive-bias conversion for the text mask.
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if necessary
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time embedding
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        # `Timesteps` always returns f32 — cast to the sample dtype.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        # 1.5 additional conditioning embeddings (model-family specific)
        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        # 1.75 optional projection of the text encoder states
        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kadinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        # 2. pre-process
        sample = self.conv_in(sample)  # 1, 320, 64, 64

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)

        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            # Add the ControlNet residuals to the skip connections.
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )

            # Custom addition: keep each up-block's feature map for callers.
            out_list.append(sample)

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            # NOTE(review): this branch drops `out_list`, unlike the default
            # return below — confirm whether that asymmetry is intended.
            return (sample,)

        return UNet2DConditionOutput(sample=sample), out_list

    self.forward = forward
394
+
395
+
396
+
397
+
398
+
399
+
400
+
401
class AttentionControl(abc.ABC):
    """Base class for objects that observe (and possibly edit) attention maps.

    Instances are called once per attention layer; subclasses implement
    `forward` to process each map. Counters track the current layer and
    diffusion step so `between_steps` fires once per step.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1  # filled in by register_attention_control
        self.cur_att_layer = 0

    def step_callback(self, x_t):
        """Hook for editing the latent between diffusion steps (identity here)."""
        return x_t

    def between_steps(self):
        """Hook invoked after every attention layer of a step has been seen."""
        return

    @property
    def num_uncond_att_layers(self):
        # Number of leading (unconditional) layers to skip; none by default.
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # Classifier-free guidance batch: only the conditional half
                # (second half of the batch) is processed, in place.
                half = attn.shape[0] // 2
                attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            # A full pass over the UNet finished: advance to the next step.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        """Rewind the step/layer counters for a fresh diffusion run."""
        self.cur_step = 0
        self.cur_att_layer = 0
440
+
441
+
442
class EmptyControl(AttentionControl):
    """No-op controller: records nothing and leaves attention maps untouched."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Identity pass-through.
        return attn
446
+
447
+
448
class AttentionStore(AttentionControl):
    """Collects attention maps per UNet location ("down"/"mid"/"up") and kind
    ("cross"/"self"), per diffusion step, with an optional running global sum.
    """

    @staticmethod
    def get_empty_store():
        keys = ("down_cross", "mid_cross", "up_cross",
                "down_self", "mid_self", "up_self")
        return {k: [] for k in keys}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        # Only keep maps small enough to avoid blowing up memory.
        if attn.shape[1] <= self.max_size ** 2:
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        # Expose the maps of the step that just finished...
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if not self.global_store:
                    # First step: seed the accumulator with this step's maps.
                    self.global_store = self.step_store
                else:
                    # Later steps: accumulate map-by-map (in place).
                    for key, stored in self.global_store.items():
                        for i in range(len(stored)):
                            stored[i] += self.step_store[key][i].detach()
        # ...and start collecting the next step into a fresh store.
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        """Return the maps recorded during the most recent step."""
        return self.attention_store

    def get_average_global_attention(self):
        """Return the globally accumulated maps averaged over all steps."""
        return {key: [item / self.cur_step for item in self.global_store[key]]
                for key in self.attention_store}

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, max_size=32, save_global_store=False):
        '''
        Initialize an empty AttentionStore
        :param max_size: maps wider than max_size**2 tokens are skipped
        :param save_global_store: also accumulate maps across steps
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.max_size = max_size
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        self.curr_step_index = 0
501
+
502
def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    """Average the stored attention maps of spatial resolution `res` over the
    requested UNet locations, for the prompt at index `select`.

    Returns a [res, res, tokens] tensor with layers and heads averaged out.
    """
    attn_kind = 'cross' if is_cross else 'self'
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    gathered = []
    for location in from_where:
        for item in attention_maps[f"{location}_{attn_kind}"]:
            # Keep only maps at the requested resolution.
            if item.shape[1] == num_pixels:
                per_prompt = item.reshape(len(prompts), -1, res, res, item.shape[-1])
                gathered.append(per_prompt[select])
    stacked = torch.cat(gathered, dim=0)
    return stacked.sum(0) / stacked.shape[0]
514
+
515
def aggregate_attention1(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    """Variant of aggregate_attention: instead of averaging over every
    collected map, keep only the second matching map (index 1) and average
    over its leading (head) dimension.
    """
    attn_kind = 'cross' if is_cross else 'self'
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    gathered = []
    for location in from_where:
        for item in attention_maps[f"{location}_{attn_kind}"]:
            if item.shape[1] == num_pixels:
                per_prompt = item.reshape(len(prompts), -1, res, res, item.shape[-1])
                gathered.append(per_prompt[select])
    # NOTE: assumes at least two maps matched the requested resolution.
    chosen = gathered[1]
    return chosen.sum(0) / chosen.shape[0]
529
+
530
+
531
def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    """Render one grayscale heat-map per prompt token from the aggregated
    cross-attention and display them side by side.

    Args:
        tokenizer: tokenizer providing `encode`/`decode`.
        prompts: list of prompts in the stored batch.
        attention_store: store holding the collected attention maps.
        res: spatial resolution of the maps to aggregate.
        from_where: UNet locations to aggregate over ("down"/"mid"/"up").
        select: prompt index to visualize.
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    # BUG FIX: aggregate_attention takes `prompts` as its first argument;
    # the original call omitted it and raised a TypeError.
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        # Normalize to [0, 255] and expand to a 3-channel image.
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))
545
+
546
+
547
def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0, prompts=("",)):
    """Visualize the top SVD components of the aggregated self-attention.

    Args:
        attention_store: store holding the collected attention maps.
        res: spatial resolution of the maps to aggregate.
        from_where: UNet locations to aggregate over ("down"/"mid"/"up").
        max_com: number of leading singular components to display.
        select: prompt index inside the stored batch.
        prompts: prompt list forwarded to `aggregate_attention`. BUG FIX:
            the original call omitted this required first argument and always
            raised a TypeError; it is now a backward-compatible keyword
            parameter defaulting to a single (empty) prompt.
    """
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
    # Center the rows, then decompose into principal attention patterns.
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        # Normalize each component into the displayable [0, 255] range.
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))
561
+
562
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Return a copy of `image` with `text` centered in a white strip
    appended below it."""
    h, w, c = image.shape
    offset = int(h * .2)  # height of the caption strip
    canvas = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    canvas[:h] = image
    text_w, text_h = cv2.getTextSize(text, font, 1, 2)[0]
    # Center horizontally; place the baseline inside the strip.
    text_x = (w - text_w) // 2
    text_y = h + offset - text_h // 2
    cv2.putText(canvas, text, (text_x, text_y), font, 1, text_color, 2)
    return canvas
573
+
574
+
575
def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tile `images` into a grid with `num_rows` rows (padding with blank
    frames) and display the result."""
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        # A single image: wrap it so the grid logic below applies.
        images = [images]
        num_empty = 0

    blank = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [frame.astype(np.uint8) for frame in images] + [blank] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)  # gap between tiles, in pixels
    num_cols = num_items // num_rows
    grid = np.ones((h * num_rows + offset * (num_rows - 1),
                    w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for row in range(num_rows):
        for col in range(num_cols):
            top = row * (h + offset)
            left = col * (w + offset)
            grid[top: top + h, left: left + w] = images[row * num_cols + col]

    pil_img = Image.fromarray(grid)
    display(pil_img)
600
+
601
def self_cross_attn(self_attn, cross_attn):
    """Modulate a self-attention map by a cross-attention map.

    Args:
        self_attn: tensor of shape [res, res, res*res].
        cross_attn: tensor whose squeezed shape is [res, res].

    Returns:
        Tensor of shape [1, 1, res, res]: the per-pixel self-attention
        weighted by the flattened cross-attention, averaged over targets.
    """
    cross_attn = cross_attn.squeeze()
    res = self_attn.shape[0]
    assert res == cross_attn.shape[-1]
    # Flatten the cross-attention so it broadcasts along the last axis.
    weights = cross_attn.reshape([res * res])
    modulated = weights * self_attn
    return modulated.mean(-1).unsqueeze(0).unsqueeze(0)
_utils/config.yaml ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ attn_dist_mode: v0
2
+ attn_positional_bias: rope
3
+ attn_positional_bias_n_spatial: 16
4
+ causal_norm: quiet_softmax
5
+ coord_dim: 2
6
+ d_model: 320
7
+ dropout: 0.0
8
+ feat_dim: 7
9
+ feat_embed_per_dim: 8
10
+ nhead: 4
11
+ num_decoder_layers: 6
12
+ num_encoder_layers: 6
13
+ pos_embed_per_dim: 32
14
+ spatial_pos_cutoff: 256
15
+ window: 4
_utils/example_config.yaml ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ batch_size: 1
2
+ crop_size:
3
+ - 256
4
+ - 256
5
+ detection_folders:
6
+ - TRA
7
+ dropout: 0.01
8
+ example_images: False # Slow
9
+ input_train:
10
+ - data/ctc/Fluo-N2DL-HeLa/01
11
+ input_val:
12
+ - data/ctc/Fluo-N2DL-HeLa/02
13
+ max_tokens: 2048
14
+ name: example
15
+ ndim: 2
16
+ num_decoder_layers: 5
17
+ num_encoder_layers: 5
18
+ outdir: runs
19
+ distributed: False
20
+ window: 4
_utils/load_models.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from config import RunConfig
2
+ import torch
3
+ from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
4
+ import torch.nn as nn
5
+
6
def load_stable_diffusion_model(config: RunConfig):
    """Load the Stable Diffusion pipeline selected by `config` onto the CPU.

    `config.sd_2_1` chooses between SD 2.1-base and SD 1.4.
    """
    device = torch.device('cpu')

    if config.sd_2_1:
        repo_id = "stabilityai/stable-diffusion-2-1-base"
    else:
        repo_id = "CompVis/stable-diffusion-v1-4"
    stable = StableDiffusionPipeline.from_pretrained(repo_id).to(device)
    return stable
16
+
_utils/load_track_data.py ADDED
@@ -0,0 +1,118 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from glob import glob
3
+ from pathlib import Path
4
+ from natsort import natsorted
5
+ from PIL import Image
6
+ import numpy as np
7
+ import tifffile
8
+ import skimage.io as io
9
+ import torchvision.transforms as T
10
+ import cv2
11
+ from tqdm import tqdm
12
+ from models.tra_post_model.trackastra.utils import normalize_01, normalize
13
+ IMG_SIZE = 512
14
+
15
def _load_tiffs(folder: Path, dtype=None):
    """Load a sequence of tiff files from a folder into a 3D numpy array.

    RGB(A) frames are reduced to grayscale with the ITU-R 601 luma weights
    (alpha is dropped first).

    Raises:
        FileNotFoundError: if the folder contains no ``*.tif`` files
            (BUG FIX: previously crashed with an opaque IndexError).
    """
    files = sorted(folder.glob("*.tif"))
    if not files:
        raise FileNotFoundError(f"No .tif files found in {folder}")
    end_frame = len(files)
    # Probe the first frame (of the *sorted* list — the original probed an
    # unsorted glob) to decide whether grayscale conversion is needed.
    turn_gray = tifffile.imread(files[0]).ndim == 3
    frames = []
    for f in tqdm(
        files[0 : end_frame : 1],
        leave=False,
        desc=f"Loading [0:{end_frame}]",
    ):
        img = tifffile.imread(f).astype(dtype)
        if turn_gray and img.ndim == 3:
            if img.shape[-1] > 3:
                img = img[..., :3]  # drop the alpha channel
            img = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2])
        frames.append(img)
    return np.stack(frames)
48
+
49
+
50
def load_track_images(file_dir):
    """Load a tracking sequence of .tif frames found (recursively) under
    `file_dir`, preparing three parallel representations:

    - `imgs`: ImageNet-normalized tensors resized to IMG_SIZE (segmentation input)
    - `imgs_raw`: untouched frames as read from disk
    - `images_stable`: resized tensors shifted to [-0.5, 0.5] (diffusion input)
    - `imgs_` / `imgs_01`: trackastra-normalized stacks (tracking input)

    Returns:
        (imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width)
        where height/width are taken from the last frame processed.
    """

    # suffix_ = [".png", ".tif", ".tiff", ".jpg"]
    def find_tif_dir(root_dir):
        """Recursively collect .tif files, skipping macOS resource folders."""
        tif_files = []
        for dirpath, _, filenames in os.walk(root_dir):
            if '__MACOSX' in dirpath:
                continue
            for f in filenames:
                if f.lower().endswith('.tif'):
                    tif_files.append(os.path.join(dirpath, f))
        return tif_files

    tif_dir = find_tif_dir(file_dir)
    print(f"Found {len(tif_dir)} tif images in {file_dir}")
    print(f"First 5 tif images: {tif_dir[:5]}")
    assert len(tif_dir) > 0, f"No tif images found in {file_dir}"
    images = natsorted(tif_dir)
    imgs = []
    imgs_raw = []
    images_stable = []
    # load images for seg and track
    for img_path in tqdm(images, desc="Loading images"):
        img = tifffile.imread(img_path)
        img_raw = io.imread(img_path)

        if img.dtype == 'uint16':
            # Min-max scale 16-bit data to uint8, then fake an RGB stack.
            img = ((img - img.min()) / (img.max() - img.min() + 1e-6) * 255).astype(np.uint8)
            img = np.stack([img] * 3, axis=-1)
            w, h = img.shape[1], img.shape[0]
        else:
            # 8-bit (or other) frames are re-read via PIL as RGB.
            img = Image.open(img_path).convert("RGB")
            w, h = img.size

        img = T.Compose([
            T.ToTensor(),
            T.Resize((IMG_SIZE, IMG_SIZE)),
        ])(img)

        # [0,1] -> [-0.5, 0.5] for the stable-diffusion branch.
        image_stable = img - 0.5
        img = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)


        imgs.append(img)
        imgs_raw.append(img_raw)
        images_stable.append(image_stable)

    # NOTE(review): h/w come from the *last* frame — assumes all frames in
    # the sequence share one size; confirm for mixed-size inputs.
    height = h
    width = w
    imgs = np.stack(imgs, axis=0)
    imgs_raw = np.stack(imgs_raw, axis=0)
    images_stable = np.stack(images_stable, axis=0)

    # track data
    imgs_ = _load_tiffs(Path(file_dir), dtype=np.float32)
    imgs_01 = np.stack([
        normalize_01(_x) for _x in tqdm(imgs_, desc="Normalizing", leave=False)
    ])
    imgs_ = np.stack([
        normalize(_x) for _x in tqdm(imgs_, desc="Normalizing", leave=False)
    ])

    return imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width
114
+
115
if __name__ == "__main__":
    # Smoke test: load a local Cell Tracking Challenge sequence and print shapes.
    file_dir = "data/2D+Time/DIC-C2DH-HeLa/train/DIC-C2DH-HeLa/02"
    imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width = load_track_images(file_dir)
    print(imgs.shape, imgs_raw.shape, images_stable.shape, imgs_.shape, imgs_01.shape, height, width)
_utils/misc_helper.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import logging
2
+ import os
3
+ import random
4
+ import shutil
5
+ from collections.abc import Mapping
6
+ from datetime import datetime
7
+
8
+ import numpy as np
9
+ import torch
10
+ import torch.distributed as dist
11
+
12
+
13
def basicConfig(*args, **kwargs):
    """No-op stand-in for logging.basicConfig; accepts and ignores everything."""
    return


# To prevent duplicate logs, we mask this baseConfig setting
logging.basicConfig = basicConfig
19
+
20
+
21
def create_logger(name, log_file, level=logging.INFO):
    """Build a logger that emits identically formatted records both to
    `log_file` and to the console.

    Args:
        name: logger name (loggers are cached per name by the stdlib).
        log_file: path for the FileHandler output.
        level: minimum level for the logger (INFO by default).
    """
    logger = logging.getLogger(name)
    fmt = logging.Formatter(
        "[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s"
    )
    logger.setLevel(level)
    # File first, console second — both share the same format.
    for handler in (logging.FileHandler(log_file), logging.StreamHandler()):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger
34
+
35
def get_current_time():
    """Return the current local time formatted as 'YYYYMMDD_HHMMSS'."""
    return datetime.now().strftime("%Y%m%d_%H%M%S")
_utils/seg_eval.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
def iou_torch(inst1, inst2):
    """Intersection-over-union of two masks; NaN when both masks are empty."""
    intersection = torch.logical_and(inst1, inst2).sum().float()
    union = torch.logical_or(inst1, inst2).sum().float()
    if union == 0:
        # Both masks empty: IoU is undefined.
        return torch.tensor(float('nan'))
    return intersection / union
10
+
11
def get_instances_torch(mask):
    """Return one boolean mask per non-background (non-zero) label in `mask`."""
    labels = torch.unique(mask)
    return [(mask == lbl) for lbl in labels if lbl != 0]
15
+
16
def compute_instance_miou(pred_mask, gt_mask):
    """Mean best-match IoU over ground-truth instances.

    Both masks are integer [H, W] label maps (0 = background). For every GT
    instance the best IoU against any predicted instance is taken; the mean
    of those best scores is returned (NaN when `gt_mask` has no instances).
    """
    pred_instances = get_instances_torch(pred_mask)
    gt_instances = get_instances_torch(gt_mask)

    best_scores = []
    for gt_inst in gt_instances:
        best = torch.tensor(0.0).to(pred_mask.device)
        for pred_inst in pred_instances:
            score = iou_torch(pred_inst, gt_inst)
            if score > best:
                best = score
        best_scores.append(best)

    # No GT instances: the metric is undefined.
    if not best_scores:
        return torch.tensor(float('nan'))
    return torch.nanmean(torch.stack(best_scores))
34
+
35
+ from torch import Tensor
36
+
37
+
38
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Dice coefficient averaged over all batches, or over a single mask."""
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    # Reduce the spatial dims only, or spatial + batch when requested.
    if input.dim() == 2 or not reduce_batch_first:
        sum_dim = (-1, -2)
    else:
        sum_dim = (-1, -2, -3)

    intersection = 2 * (input * target).sum(dim=sum_dim)
    combined = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
    # Empty-vs-empty masks count as a perfect match.
    combined = torch.where(combined == 0, intersection, combined)

    dice = (intersection + epsilon) / (combined + epsilon)
    return dice.mean()
51
+
52
+
53
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Dice coefficient averaged over all classes (class dim folded into batch)."""
    flat_input = input.flatten(0, 1)
    flat_target = target.flatten(0, 1)
    return dice_coeff(flat_input, flat_target, reduce_batch_first, epsilon)
56
+
57
+
58
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Dice loss in [0, 1]: one minus the (multi-class) Dice coefficient."""
    score_fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - score_fn(input, target, reduce_batch_first=True)
_utils/track_args.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import configargparse
2
+
3
+
4
def _str2bool(value):
    """Parse a boolean CLI/config value strictly.

    argparse's ``type=bool`` treats ANY non-empty string — including
    "False" — as True; this converter parses common spellings instead
    and raises ValueError (reported by argparse) on anything else.
    """
    if isinstance(value, bool):
        return value
    text = str(value).strip().lower()
    if text in ("yes", "true", "t", "1"):
        return True
    if text in ("no", "false", "f", "0"):
        return False
    raise ValueError(f"Boolean value expected, got {value!r}")


def parse_train_args():
    """Parse training arguments for the tracking model.

    Values may come from the command line or from a YAML config file
    (``-c/--config``). Unknown extra arguments are tolerated on purpose
    (``parse_known_args``) so wrapper scripts can pass additional flags.

    Returns
    -------
    argparse.Namespace
        The parsed arguments.
    """
    parser = configargparse.ArgumentParser(
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        allow_abbrev=False,
    )
    parser.add_argument(
        "-c",
        "--config",
        default="_utils/example_config.yaml",
        is_config_file=True,
        help="config file path",
    )
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("-o", "--outdir", type=str, default="runs")
    parser.add_argument("--name", type=str, help="Name to append to timestamp")
    # NOTE: boolean flags use _str2bool so "--timestamp False" actually disables them.
    parser.add_argument("--timestamp", type=_str2bool, default=True)
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="",
        help="load this model at start (e.g. to continue training)",
    )
    parser.add_argument(
        "--ndim", type=int, default=2, help="number of spatial dimensions"
    )
    parser.add_argument("-d", "--d_model", type=int, default=256)
    parser.add_argument("-w", "--window", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--warmup_epochs", type=int, default=10)
    parser.add_argument(
        "--detection_folders",
        type=str,
        nargs="+",
        default=["TRA"],
        help=(
            "Subfolders to search for detections. Defaults to `TRA`, which corresponds"
            " to using only the GT."
        ),
    )
    parser.add_argument("--downscale_temporal", type=int, default=1)
    parser.add_argument("--downscale_spatial", type=int, default=1)
    parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
    parser.add_argument("--from_subfolder", action="store_true")
    parser.add_argument("--num_encoder_layers", type=int, default=6)
    parser.add_argument("--num_decoder_layers", type=int, default=6)
    parser.add_argument("--pos_embed_per_dim", type=int, default=32)
    parser.add_argument("--feat_embed_per_dim", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.00)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_tokens", type=int, default=None)
    parser.add_argument("--delta_cutoff", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument(
        "--attn_positional_bias",
        type=str,
        choices=["rope", "bias", "none"],
        default="rope",
    )
    parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
    parser.add_argument("--attn_dist_mode", default="v0")
    parser.add_argument("--mixedp", type=_str2bool, default=True)
    parser.add_argument("--dry", action="store_true")
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--features",
        type=str,
        choices=[
            "none",
            "regionprops",
            "regionprops2",
            "patch",
            "patch_regionprops",
            "wrfeat",
        ],
        default="wrfeat",
    )
    parser.add_argument(
        "--causal_norm",
        type=str,
        choices=["none", "linear", "softmax", "quiet_softmax"],
        default="quiet_softmax",
    )
    parser.add_argument("--div_upweight", type=float, default=2)

    parser.add_argument("--augment", type=int, default=3)
    parser.add_argument("--tracking_frequency", type=int, default=-1)

    parser.add_argument("--sanity_dist", action="store_true")
    parser.add_argument("--preallocate", type=_str2bool, default=False)
    parser.add_argument("--only_prechecks", action="store_true")
    parser.add_argument(
        "--compress", type=_str2bool, default=True, help="compress dataset"
    )

    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb", "none"],
    )
    parser.add_argument("--wandb_project", type=str, default="trackastra")
    parser.add_argument(
        "--crop_size",
        type=int,
        nargs="+",
        default=None,
        help="random crop size for augmentation",
    )
    parser.add_argument(
        "--weight_by_ndivs",
        type=_str2bool,
        default=True,
        help="Oversample windows that contain divisions",
    )
    parser.add_argument(
        "--weight_by_dataset",
        type=_str2bool,
        default=False,
        help=(
            "Inversely weight datasets by number of samples (to counter dataset size"
            " imbalance)"
        ),
    )

    # Unknown args are deliberately ignored (wrapper scripts pass extras).
    args, _unknown_args = parser.parse_known_args()

    return args
app.py ADDED
@@ -0,0 +1,1638 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ from gradio_bbox_annotator import BBoxAnnotator
3
+ from PIL import Image
4
+ import numpy as np
5
+ import torch
6
+ import os
7
+ import shutil
8
+ import time
9
+ import json
10
+ import uuid
11
+ from pathlib import Path
12
+ import tempfile
13
+ import zipfile
14
+ from skimage import measure
15
+ from matplotlib import cm
16
+ from glob import glob
17
+ from natsort import natsorted
18
+ from huggingface_hub import HfApi, upload_file
19
+ # import spaces
20
+
21
+ # ===== 导入三个推理模块 =====
22
+ from inference_seg import load_model as load_seg_model, run as run_seg
23
+ from inference_count import load_model as load_count_model, run as run_count
24
+ from inference_track import load_model as load_track_model, run as run_track
25
+
26
+ HF_TOKEN = os.getenv("HF_TOKEN")
27
+ DATASET_REPO = "phoebe777777/celltool_feedback"
28
+
29
+
30
# ===== Clear Gradio's on-disk upload cache so stale files don't accumulate =====
print("===== clearing cache =====")
cache_path = os.path.expanduser("~/.cache/huggingface/gradio")
if os.path.exists(cache_path):
    try:
        shutil.rmtree(cache_path)
        print("✅ Deleted ~/.cache/huggingface/gradio")
    except OSError:
        # Best-effort cleanup; a locked or partially-removed cache is not
        # fatal. (Previously a bare `except:` also swallowed SystemExit /
        # KeyboardInterrupt — narrowed to filesystem errors.)
        pass
41
+
42
# ===== Global model handles (populated once by load_all_models at startup) =====
# Devices default to CPU here; the load_* functions may return a different
# device (e.g. CUDA) — TODO confirm against inference_* modules.
SEG_MODEL = None
SEG_DEVICE = torch.device("cpu")

COUNT_MODEL = None
COUNT_DEVICE = torch.device("cpu")

TRACK_MODEL = None
TRACK_DEVICE = torch.device("cpu")
51
+
52
def load_all_models():
    """Load the segmentation, counting and tracking models once at startup.

    Populates the module-level ``*_MODEL`` / ``*_DEVICE`` globals.
    """
    global SEG_MODEL, SEG_DEVICE
    global COUNT_MODEL, COUNT_DEVICE
    global TRACK_MODEL, TRACK_DEVICE

    def _banner(text):
        # Framed progress banner, identical to the original console output.
        print("\n" + "=" * 60)
        print(text)
        print("=" * 60)

    _banner("📦 Loading Segmentation Model")
    SEG_MODEL, SEG_DEVICE = load_seg_model(use_box=False)

    _banner("📦 Loading Counting Model")
    COUNT_MODEL, COUNT_DEVICE = load_count_model(use_box=False)

    _banner("📦 Loading Tracking Model")
    TRACK_MODEL, TRACK_DEVICE = load_track_model(use_box=False)

    _banner("✅ All Models Loaded Successfully")
76
+
77
# Load all models eagerly at import time so the first user request is fast.
load_all_models()

# ===== User feedback storage (local fallback directory) =====
DATASET_DIR = Path("solver_cache")
DATASET_DIR.mkdir(parents=True, exist_ok=True)
82
+
83
def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
    """Persist one feedback record to the Hugging Face dataset repo.

    Falls back to local JSON storage (``save_feedback``) when no HF token
    is configured or when the upload fails.

    Parameters
    ----------
    query_id : str
        Identifier of the inference request the feedback refers to.
    feedback_type : str
        Category of the feedback (e.g. like/dislike/correction).
    feedback_text : str, optional
        Free-form user comment.
    img_path : str, optional
        Path of the image the feedback refers to.
    bboxes : list, optional
        Bounding boxes drawn by the user (serialized as a string).
    """
    # Without a token we cannot push to the Hub — use local storage.
    if not HF_TOKEN:
        print("⚠️ No HF_TOKEN found, using local storage")
        save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
        return

    feedback_data = {
        "query_id": query_id,
        "feedback_type": feedback_type,
        "feedback_text": feedback_text,
        "image_path": img_path,
        "bboxes": str(bboxes),  # stringified so the record stays JSON-safe
        "datetime": time.strftime("%Y-%m-%d %H:%M:%S"),
        "timestamp": time.time()
    }

    # Unique per-record filename (query id + epoch seconds).
    filename = f"feedback_{query_id}_{int(time.time())}.json"
    try:
        api = HfApi()

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(feedback_data, f, indent=2, ensure_ascii=False)

        # BUG FIX: previously every record was uploaded to the literal path
        # "data/(unknown)" (an f-string with no placeholder), overwriting
        # earlier feedback. Upload under the unique filename instead.
        api.upload_file(
            path_or_fileobj=filename,
            path_in_repo=f"data/{filename}",
            repo_id=DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN
        )

        print(f"✅ Feedback saved to HF Dataset: {DATASET_REPO}")

    except Exception as e:
        print(f"⚠️ Failed to save to HF Dataset: {e}")
        # Fall back to local storage on any upload/serialization error.
        save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
    finally:
        # Always clean up the temp file (previously leaked on failure).
        if os.path.exists(filename):
            os.remove(filename)
129
+
130
+
131
def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
    """Append one user-feedback record to <DATASET_DIR>/<query_id>/feedback.json."""
    record = {
        "query_id": query_id,
        "feedback_type": feedback_type,
        "feedback_text": feedback_text,
        "image": img_path,
        "bboxes": bboxes,
        "datetime": time.strftime("%Y%m%d_%H%M%S")
    }
    feedback_file = DATASET_DIR / query_id / "feedback.json"
    feedback_file.parent.mkdir(parents=True, exist_ok=True)

    # Load any previous records; older files may hold a single dict.
    records = []
    if feedback_file.exists():
        with feedback_file.open("r") as f:
            loaded = json.load(f)
        records = loaded if isinstance(loaded, list) else [loaded]
    records.append(record)

    with feedback_file.open("w") as f:
        json.dump(records, f, indent=4, ensure_ascii=False)
156
+
157
+ # ===== 辅助函数 =====
158
def parse_first_bbox(bboxes):
    """Return the first bounding box as (xmin, ymin, xmax, ymax) floats, or None.

    Accepts dicts with x/y/width/height keys or 4+-element sequences;
    anything else yields None.
    """
    if not bboxes:
        return None
    first = bboxes[0]
    if isinstance(first, dict):
        xmin = float(first.get("x", 0))
        ymin = float(first.get("y", 0))
        xmax = xmin + float(first.get("width", 0))
        ymax = ymin + float(first.get("height", 0))
        return xmin, ymin, xmax, ymax
    if isinstance(first, (list, tuple)) and len(first) >= 4:
        return tuple(float(v) for v in first[:4])
    return None
170
+
171
def parse_bboxes(bboxes):
    """Normalize every box to a [xmin, ymin, xmax, ymax] float list.

    Returns None when the input is empty/None; unrecognized entries are
    silently skipped (so an all-invalid input yields an empty list).
    """
    if not bboxes:
        return None

    parsed = []
    for entry in bboxes:
        if isinstance(entry, dict):
            x0 = float(entry.get("x", 0))
            y0 = float(entry.get("y", 0))
            x1 = x0 + float(entry.get("width", 0))
            y1 = y0 + float(entry.get("height", 0))
            parsed.append([x0, y0, x1, y1])
        elif isinstance(entry, (list, tuple)) and len(entry) >= 4:
            parsed.append([float(entry[0]), float(entry[1]), float(entry[2]), float(entry[3])])

    return parsed
186
+
187
def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
    """Map an integer instance mask to an RGB uint8 image.

    Label 0 stays black; other labels get hues evenly spaced around the
    HSV wheel. Colors repeat every ``num_colors`` labels.
    """
    def _hsv_to_rgb(h, s, v):
        # Classic sector-based HSV -> RGB conversion (same math as before).
        sector = int(h * 6.0)
        frac = h * 6.0 - sector
        sector %= 6
        p = v * (1 - s)
        q = v * (1 - frac * s)
        t = v * (1 - (1 - frac) * s)
        by_sector = [(v, t, p), (q, v, p), (p, v, t), (p, q, v), (t, p, v), (v, p, q)]
        r, g, b = by_sector[sector]
        return int(r * 255), int(g * 255), int(b * 255)

    palette = [(0, 0, 0)]  # background stays black
    for idx in range(1, num_colors):
        hue = (idx % num_colors) / float(num_colors)
        palette.append(_hsv_to_rgb(hue, 1.0, 0.95))

    lut = np.array(palette, dtype=np.uint8)
    return lut[mask % num_colors]
212
+
213
# ===== Segmentation feature =====
# @spaces.GPU
def segment_with_choice(use_box_choice, annot_value):
    """Segmentation handler: overlay each instance in its own color plus a contour.

    Parameters: ``use_box_choice`` is "Yes"/"No" (whether user-drawn boxes
    constrain segmentation); ``annot_value`` is (image_path, bboxes) from
    the BBoxAnnotator widget. Returns (overlay PIL image, path of the raw
    uint16 instance mask saved as .tif), or (None, None) on failure.
    """
    if annot_value is None or len(annot_value) < 1:
        print("❌ No annotation input")
        return None, None

    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []

    print(f"🖼️ Image path: {img_path}")
    box_array = None
    if use_box_choice == "Yes" and bboxes:
        # Legacy single-box variant kept for reference:
        # box = parse_first_bbox(bboxes)
        # if box:
        #     xmin, ymin, xmax, ymax = map(int, box)
        #     box_array = [[xmin, ymin, xmax, ymax]]
        #     print(f"📦 Using bounding box: {box_array}")
        box = parse_bboxes(bboxes)
        if box:
            box_array = box
            print(f"📦 Using bounding boxes: {box_array}")

    # Run the segmentation model (instance-labeled 2-D mask expected).
    try:
        mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
        print("📏 mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
    except Exception as e:
        print(f"❌ Inference failed: {str(e)}")
        return None, None

    # Save the raw mask as a uint16 TIF so the user can download it.
    temp_mask_file = tempfile.NamedTemporaryFile(delete=False, suffix=".tif")
    mask_img = Image.fromarray(mask.astype(np.uint16))
    mask_img.save(temp_mask_file.name)
    print(f"💾 Original mask saved to: {temp_mask_file.name}")

    # Load the original image for the visualization overlay.
    try:
        img = Image.open(img_path)
        print("📷 Image mode:", img.mode, "size:", img.size)
    except Exception as e:
        print(f"❌ Failed to open image: {e}")
        return None, None

    try:
        # Resize the image to the mask resolution; mask.shape[::-1] gives (W, H).
        img_rgb = img.convert("RGB").resize(mask.shape[::-1], resample=Image.BILINEAR)
        img_np = np.array(img_rgb, dtype=np.float32)
        if img_np.max() > 1.5:
            img_np = img_np / 255.0
    except Exception as e:
        print(f"❌ Error in image conversion/resizing: {e}")
        return None, None

    mask_np = np.array(mask)
    inst_mask = mask_np.astype(np.int32)
    unique_ids = np.unique(inst_mask)
    num_instances = len(unique_ids[unique_ids != 0])
    print(f"✅ Instance IDs found: {unique_ids}, Total instances: {num_instances}")

    if num_instances == 0:
        print("⚠️ No instance found, returning dummy red image")
        return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None

    # ==== Color overlay (one color per instance) ====
    overlay = img_np.copy()
    alpha = 0.5
    # cmap = cm.get_cmap("hsv", num_instances + 1)

    for inst_id in np.unique(inst_mask):
        if inst_id == 0:
            continue
        binary_mask = (inst_mask == inst_id).astype(np.uint8)
        # color = np.array(cmap(inst_id / (num_instances + 1))[:3])  # RGB only, ignore alpha
        color = get_well_spaced_color(inst_id)
        overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color

        # Draw instance contours on top of the blend.
        contours = measure.find_contours(binary_mask, 0.5)
        for contour in contours:
            contour = contour.astype(np.int32)
            # Clamp contour coordinates to the image bounds.
            valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
            valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
            overlay[valid_y, valid_x] = [1.0, 1.0, 0.0]  # yellow contour

    overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)

    return Image.fromarray(overlay), temp_mask_file.name
304
+
305
# ===== Counting feature =====
# @spaces.GPU
def count_cells_handler(use_box_choice, annot_value):
    """Counting handler with optional bounding boxes.

    Parameters: ``use_box_choice`` is "Yes"/"No"; ``annot_value`` is
    (image_path, bboxes) from the BBoxAnnotator widget.

    Returns a 3-tuple (overlay PIL image, density-map .npy path, message).
    BUG FIX: every error path previously returned only two values, which
    breaks a Gradio binding expecting three outputs — all paths now return
    three.
    """
    if annot_value is None or len(annot_value) < 1:
        return None, None, "⚠️ Please provide an image."

    image_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []

    print(f"🖼️ Image path: {image_path}")
    box_array = None
    if use_box_choice == "Yes" and bboxes:
        parsed = parse_bboxes(bboxes)
        if parsed:
            box_array = parsed
            print(f"📦 Using bounding boxes: {box_array}")

    try:
        print(f"🔢 Counting - Image: {image_path}")

        result = run_count(
            COUNT_MODEL,
            image_path,
            box=box_array,
            device=COUNT_DEVICE,
            visualize=True
        )

        if 'error' in result:
            return None, None, f"❌ Counting failed: {result['error']}"

        count = result['count']
        density_map = result['density_map']

        # Persist the raw density map so the user can download it.
        temp_density_file = tempfile.NamedTemporaryFile(delete=False, suffix=".npy")
        np.save(temp_density_file.name, density_map)
        print(f"💾 Density map saved to {temp_density_file.name}")

        try:
            img = Image.open(image_path)
            print("📷 Image mode:", img.mode, "size:", img.size)
        except Exception as e:
            print(f"❌ Failed to open image: {e}")
            return None, None, f"❌ Counting failed: {e}"

        try:
            # Resize image to the density-map resolution; shape[::-1] -> (W, H).
            img_rgb = img.convert("RGB").resize(density_map.shape[::-1], resample=Image.BILINEAR)
            img_np = np.array(img_rgb, dtype=np.float32)
            # Min-max normalize to [0, 1] for blending.
            img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
        except Exception as e:
            print(f"❌ Error in image conversion/resizing: {e}")
            return None, None, f"❌ Counting failed: {e}"

        # Normalize the density map to [0, 1] for colorization.
        density_normalized = density_map.copy()
        if density_normalized.max() > 0:
            density_normalized = (density_normalized - density_normalized.min()) / (
                density_normalized.max() - density_normalized.min()
            )

        # Apply a jet colormap and blend only where density is significant
        # (> 1% of the maximum) to keep the background visible.
        cmap = cm.get_cmap("jet")
        alpha = 0.3
        density_colored = cmap(density_normalized)[:, :, :3]  # drop the alpha channel

        overlay = img_np.copy()
        significant_mask = density_normalized > 0.01
        overlay[significant_mask] = (
            (1 - alpha) * overlay[significant_mask] + alpha * density_colored[significant_mask]
        )
        overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)

        result_text = f"✅ Detected {round(count)} objects"
        # BUG FIX: previously referenced `box`, which is unbound when no
        # boxes were drawn (NameError); report the parsed boxes instead.
        if use_box_choice == "Yes" and box_array:
            result_text += f"\n📦 Using bounding box: {box_array}"

        print(f"✅ Counting done - Count: {count:.1f}")

        return Image.fromarray(overlay), temp_density_file.name, result_text

    except Exception as e:
        print(f"❌ Counting error: {e}")
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Counting failed: {str(e)}"
410
+
411
+ # ===== Tracking Functionality =====
412
def find_tif_dir(root_dir):
    """Return the first directory under ``root_dir`` containing .tif files.

    macOS '__MACOSX' metadata directories are skipped; None when nothing
    is found.
    """
    for current_dir, _subdirs, files in os.walk(root_dir):
        if '__MACOSX' in current_dir:
            continue
        has_tif = any(name.lower().endswith('.tif') for name in files)
        if has_tif:
            return current_dir
    return None
420
+
421
def is_valid_tiff(filepath):
    """Return True when ``filepath`` opens and verifies as an image via PIL."""
    try:
        with Image.open(filepath) as handle:
            handle.verify()
    except Exception:
        # Any open/verify failure means the file is not a usable image.
        return False
    return True
429
+
430
def find_valid_tif_dir(root_dir):
    """Return the first directory under ``root_dir`` with at least one valid TIFF.

    Skips macOS metadata ('__MACOSX' directories, '._' resource forks) and
    any file that fails PIL verification; None when nothing qualifies.
    """
    for current_dir, _subdirs, files in os.walk(root_dir):
        if '__MACOSX' in current_dir:
            continue

        candidates = [
            os.path.join(current_dir, name)
            for name in files
            if name.lower().endswith(('.tif', '.tiff')) and not name.startswith('._')
        ]
        if not candidates:
            continue

        verified = [path for path in candidates if is_valid_tiff(path)]
        if verified:
            print(f"✅ Found {len(verified)} valid TIFF files in: {current_dir}")
            return current_dir

    return None
452
+
453
def create_ctc_results_zip(output_dir):
    """Bundle CTC-format tracking results into a downloadable ZIP.

    Parameters
    ----------
    output_dir : str
        Directory containing tracking results (res_track.txt, masks, ...).

    Returns
    -------
    str
        Path to the created ZIP file (in a fresh temp directory).
    """
    staging_dir = tempfile.mkdtemp()
    zip_name = f"tracking_results_{time.strftime('%Y%m%d_%H%M%S')}.zip"
    zip_path = os.path.join(staging_dir, zip_name)

    print(f"📦 Creating results ZIP: {zip_path}")

    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        # Archive every file under output_dir, preserving relative paths.
        for root, dirs, files in os.walk(output_dir):
            for name in files:
                abs_path = os.path.join(root, name)
                rel_path = os.path.relpath(abs_path, output_dir)
                archive.write(abs_path, rel_path)
                print(f" 📄 Added: {rel_path}")

        # Include a short README describing the contents.
        readme_content = f"""Tracking Results Summary
========================

Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}

Files:
------
- res_track.txt: CTC format tracking data
Format: track_id start_frame end_frame parent_id

- Segmentation masks

For more information on CTC format:
http://celltrackingchallenge.net/
"""
        archive.writestr("README.txt", readme_content)

    print(f"✅ ZIP created: {zip_path} ({os.path.getsize(zip_path) / 1024:.1f} KB)")
    return zip_path
504
+
505
+ # 使用更智能的颜色分配 - 让相邻的ID颜色差异更大
506
def get_well_spaced_color(track_id, num_colors=256):
    """Deterministic, well-spaced RGB color for a track id.

    Hues follow the golden-ratio sequence so consecutive ids land far
    apart on the color wheel; saturation/value are fixed high for
    contrast. Returns an RGB float array with components in [0, 1].
    (``num_colors`` is retained for interface compatibility; unused.)
    """
    import colorsys
    GOLDEN_RATIO = 0.618033988749895
    hue = (track_id * GOLDEN_RATIO) % 1.0
    return np.array(colorsys.hsv_to_rgb(hue, 0.9, 0.95))
516
+
517
+
518
def extract_first_frame(tif_dir):
    """Return the path of the first valid TIF frame in ``tif_dir``, or None.

    Frames are considered in natural sort order; macOS '._' resource forks
    and files failing TIFF verification are skipped.
    """
    candidates = natsorted(
        glob(os.path.join(tif_dir, "*.tif")) + glob(os.path.join(tif_dir, "*.tiff"))
    )
    for path in candidates:
        if os.path.basename(path).startswith('._'):
            continue
        if is_valid_tiff(path):
            return path
    return None
535
+
536
def create_tracking_visualization(tif_dir, output_dir, valid_tif_files):
    """
    Create an animated GIF showing tracked objects with consistent colors.

    Parameters:
    -----------
    tif_dir : str
        Directory containing input TIF frames (currently unused; frames
        come from ``valid_tif_files``).
    output_dir : str
        Directory containing tracking results (masks)
    valid_tif_files : list
        List of valid TIF file paths

    Returns:
    --------
    video_path : str
        Path to generated visualization (GIF, or a static frame fallback)
    """
    import numpy as np
    from skimage import measure
    import tifffile

    # Look for tracking mask files in the output directory.
    # Common CTC formats: man_track*.tif, mask*.tif, or numbered masks.
    # BUG FIX: "*.tif" also matches the first two patterns, so the naive
    # concatenation contained duplicates and misaligned frame/mask pairing;
    # dedupe (order-preserving) before sorting.
    mask_candidates = (glob(os.path.join(output_dir, "mask*.tif")) +
                       glob(os.path.join(output_dir, "man_track*.tif")) +
                       glob(os.path.join(output_dir, "*.tif")))
    mask_files = natsorted(dict.fromkeys(mask_candidates))

    if not mask_files:
        print("⚠️ No mask files found in output directory")
        # Return first frame as fallback
        return valid_tif_files[0]

    print(f"📊 Found {len(mask_files)} mask files")

    frames = []
    alpha = 0.3  # Transparency for overlay

    # Process each (frame, mask) pair.
    num_frames = min(len(valid_tif_files), len(mask_files))
    for i in range(num_frames):
        try:
            # Load original image using tifffile (handles ZSTD compression)
            try:
                img_np = tifffile.imread(valid_tif_files[i])

                # Normalize to [0, 1] range based on actual data type and values
                if img_np.dtype == np.uint8:
                    img_np = img_np.astype(np.float32) / 255.0
                elif img_np.dtype == np.uint16:
                    # Normalize uint16 to [0, 1] using actual min/max
                    img_min, img_max = img_np.min(), img_np.max()
                    if img_max > img_min:
                        img_np = (img_np.astype(np.float32) - img_min) / (img_max - img_min)
                    else:
                        img_np = img_np.astype(np.float32) / 65535.0
                else:
                    # For float or other types, normalize based on actual range
                    img_np = img_np.astype(np.float32)
                    img_min, img_max = img_np.min(), img_np.max()
                    if img_max > img_min:
                        img_np = (img_np - img_min) / (img_max - img_min)
                    else:
                        img_np = np.clip(img_np, 0, 1)

                # Convert to RGB if grayscale
                if img_np.ndim == 2:
                    img_np = np.stack([img_np]*3, axis=-1)
                img_np = img_np.astype(np.float32)
                if img_np.max() > 1.5:
                    img_np = img_np / 255.0
            except Exception as e:
                print(f"⚠️ Error loading image frame {i}: {e}")
                # Fallback to PIL
                img = Image.open(valid_tif_files[i]).convert("RGB")
                img_np = np.array(img, dtype=np.float32) / 255.0

            # Load tracking mask using tifffile (handles ZSTD compression)
            try:
                mask = tifffile.imread(mask_files[i])
            except Exception as e:
                print(f"⚠️ Error loading mask frame {i}: {e}")
                # Fallback to PIL
                mask = np.array(Image.open(mask_files[i]))

            # Resize mask to match image if needed (nearest-neighbor keeps labels).
            if mask.shape[:2] != img_np.shape[:2]:
                from scipy.ndimage import zoom
                zoom_factors = [img_np.shape[0] / mask.shape[0], img_np.shape[1] / mask.shape[1]]
                mask = zoom(mask, zoom_factors, order=0).astype(mask.dtype)

            # Create overlay
            overlay = img_np.copy()

            # Get unique track IDs (excluding background 0)
            track_ids = np.unique(mask)
            track_ids = track_ids[track_ids != 0]

            # Color each tracked object with a color consistent across frames.
            for track_id in track_ids:
                binary_mask = (mask == track_id)
                color = get_well_spaced_color(int(track_id))
                overlay[binary_mask] = (1 - alpha) * overlay[binary_mask] + alpha * color

                # Draw contours (optional, adds yellow boundaries)
                try:
                    contours = measure.find_contours(binary_mask.astype(np.uint8), 0.5)
                    for contour in contours:
                        contour = contour.astype(np.int32)
                        valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
                        valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
                        overlay[valid_y, valid_x] = [1.0, 1.0, 0.0]  # Yellow contour
                except Exception:
                    pass  # Skip contours if they fail

            # Convert to uint8
            overlay_uint8 = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
            frames.append(Image.fromarray(overlay_uint8))

            if i % 10 == 0 or i == num_frames - 1:
                print(f" 📸 Processed frame {i+1}/{num_frames}")

        except Exception as e:
            print(f"⚠️ Error processing frame {i}: {e}")
            import traceback
            traceback.print_exc()
            continue

    if not frames:
        print("⚠️ No frames were processed successfully")
        return valid_tif_files[0]

    # Save as animated GIF
    try:
        temp_gif = tempfile.NamedTemporaryFile(delete=False, suffix=".gif")
        frames[0].save(
            temp_gif.name,
            save_all=True,
            append_images=frames[1:],
            duration=200,  # 200ms per frame = 5fps
            loop=0
        )
        temp_gif.close()  # Close the file handle
        print(f"✅ Created tracking visualization GIF: {temp_gif.name}")
        print(f" Size: {os.path.getsize(temp_gif.name)} bytes, Frames: {len(frames)}")
        return temp_gif.name
    except Exception as e:
        print(f"⚠️ Failed to create GIF: {e}")
        import traceback
        traceback.print_exc()
        # Return first frame as static image fallback
        try:
            temp_img = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
            frames[0].save(temp_img.name)
            temp_img.close()
            return temp_img.name
        except Exception:
            return valid_tif_files[0]
709
+
710
# @spaces.GPU
def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
    """Run cell tracking on an uploaded ZIP of TIF frames.

    Parameters
    ----------
    use_box_choice : str
        "Yes" or "No" - whether to seed tracking with first-frame boxes.
    first_frame_annot : tuple or None
        (image_path, bboxes) from BBoxAnnotator; only consulted when the
        user selected "Yes" and actually drew boxes on the first frame.
    zip_file_obj : File
        Uploaded ZIP file containing the TIF sequence.

    Returns
    -------
    tuple
        (results_zip_path, status_text, download_visibility_update,
        tracking_visualization) matching the Gradio output components.
    """
    if zip_file_obj is None:
        return None, "⚠️ 请上传包含视频帧的压缩包 (.zip)", None, None

    temp_dir = None
    output_temp_dir = None

    try:
        # Parse bounding boxes if provided by the annotator.
        box_array = None
        if use_box_choice == "Yes" and first_frame_annot is not None:
            if isinstance(first_frame_annot, (list, tuple)) and len(first_frame_annot) > 1:
                bboxes = first_frame_annot[1]
                if bboxes:
                    parsed = parse_bboxes(bboxes)
                    if parsed:
                        box_array = parsed
                        print(f"📦 Using bounding boxes: {box_array}")

        # Extract input ZIP into a scratch directory.
        temp_dir = tempfile.mkdtemp()
        print(f"\n📦 Extracting to temporary directory: {temp_dir}")

        with zipfile.ZipFile(zip_file_obj.name, 'r') as zip_ref:
            extracted_count = 0
            skipped_count = 0

            for member in zip_ref.namelist():
                basename = os.path.basename(member)

                # Skip macOS resource-fork metadata and directory entries.
                if ('__MACOSX' in member or
                        basename.startswith('._') or
                        basename.startswith('.DS_Store') or
                        member.endswith('/')):
                    skipped_count += 1
                    continue

                try:
                    zip_ref.extract(member, temp_dir)
                    extracted_count += 1
                    if basename.lower().endswith(('.tif', '.tiff')):
                        print(f"📄 Extracted TIFF: {basename}")
                except Exception as e:
                    print(f"⚠️ Failed to extract {member}: {e}")

        print(f"\n📊 Extracted: {extracted_count} files, Skipped: {skipped_count} files")

        # Locate the directory that actually holds the frame sequence.
        tif_dir = find_valid_tif_dir(temp_dir)
        if tif_dir is None:
            return None, "❌ Did not find valid TIF directory", None, None

        # Keep only readable, non-metadata TIFF files, in natural frame order.
        tif_files = natsorted(glob(os.path.join(tif_dir, "*.tif")) +
                              glob(os.path.join(tif_dir, "*.tiff")))
        valid_tif_files = [f for f in tif_files
                           if not os.path.basename(f).startswith('._') and is_valid_tiff(f)]

        if len(valid_tif_files) == 0:
            return None, "❌ Did not find valid TIF files", None, None

        print(f"📈 Using {len(valid_tif_files)} TIF files")

        # First frame kept as a static fallback if visualization fails.
        first_frame_path = valid_tif_files[0]

        # Temporary output directory for CTC-format results.
        output_temp_dir = tempfile.mkdtemp()
        print(f"💾 CTC-format results will be saved to: {output_temp_dir}")

        # Run tracking with the optional bounding boxes.
        result = run_track(
            TRACK_MODEL,
            video_dir=tif_dir,
            box=box_array,
            device=TRACK_DEVICE,
            output_dir=output_temp_dir
        )

        if 'error' in result:
            return None, f"❌ Tracking failed: {result['error']}", None, None

        # Build an animated visualization of the tracked objects.
        print("\n🎬 Creating tracking visualization...")
        try:
            tracking_video = create_tracking_visualization(
                tif_dir,
                output_temp_dir,
                valid_tif_files
            )
        except Exception as e:
            print(f"⚠️ Failed to create visualization: {e}")
            import traceback
            traceback.print_exc()
            # Fall back to the raw first frame if visualization fails.
            try:
                tracking_video = Image.open(first_frame_path)
            except Exception:
                tracking_video = None

        # Package the CTC results for download.
        try:
            results_zip = create_ctc_results_zip(output_temp_dir)
        except Exception as e:
            print(f"⚠️ Failed to create ZIP: {e}")
            results_zip = None

        bbox_info = ""
        if box_array:
            bbox_info = f"\n🔲 Using bounding box: [{box_array[0][0]}, {box_array[0][1]}, {box_array[0][2]}, {box_array[0][3]}]"

        result_text = f"""✅ Tracking completed!

🖼️ Processed frames: {len(valid_tif_files)}{bbox_info}

📥 Click the button below to download CTC-format results
The results include:
- res_track.txt (CTC-format tracking data)
- Other tracking-related files
- README.txt (Results description)
"""

        # BUG FIX: the original tested `box`, which is unbound whenever the user
        # selected "Yes" but drew no boxes (NameError). Test `box_array` instead.
        if use_box_choice == "Yes" and box_array:
            result_text += f"\n📦 Using bounding box: {box_array}"

        print(f"\n✅ Tracking completed")

        # Remove the input scratch dir; keep output_temp_dir alive for download.
        if temp_dir:
            try:
                shutil.rmtree(temp_dir)
                print(f"🗑️ Cleared input temp directory")
            except Exception:
                pass

        return results_zip, result_text, gr.update(visible=True), tracking_video

    except zipfile.BadZipFile:
        return None, "❌ Not a valid ZIP file", None, None
    except Exception as e:
        import traceback
        traceback.print_exc()

        # Best-effort cleanup of both scratch directories on failure.
        for d in [temp_dir, output_temp_dir]:
            if d:
                try:
                    shutil.rmtree(d)
                except Exception:
                    pass
        return None, f"❌ Tracking failed: {str(e)}", None, None
880
+
881
+
882
+
883
# ===== Example assets =====
# Paths to bundled demo inputs for each tab (segmentation images,
# counting images, and tracking ZIP sequences).
example_images_seg = list(glob("example_imgs/seg/*"))
example_images_cnt = list(glob("example_imgs/cnt/*"))
example_tracking_zips = list(glob("example_imgs/tra/*.zip"))
888
+
889
# ===== Gradio UI =====
with gr.Blocks(
    title="Microscopy Analysis Suite",
    theme=gr.themes.Soft(),
    css="""
    .tabs button {
        font-size: 18px !important;
        font-weight: 600 !important;
        padding: 12px 20px !important;
    }
    .uniform-height {
        height: 500px !important;
        display: flex !important;
        align-items: center !important;
        justify-content: center !important;
    }

    .uniform-height img,
    .uniform-height canvas {
        max-height: 500px !important;
        object-fit: contain !important;
    }

    /* 强制密度图容器和图片高度 */
    #density_map_output {
        height: 500px !important;
    }

    #density_map_output .image-container {
        height: 500px !important;
    }

    #density_map_output img {
        height: 480px !important;
        width: auto !important;
        max-width: 90% !important;
        object-fit: contain !important;
    }
    """
) as demo:
    gr.Markdown(
        """
        # 🔬 Microscopy Image Analysis Suite

        Supporting three key tasks:
        - 🎨 **Segmentation**: Instance segmentation of microscopic objects
        - 🔢 **Counting**: Counting microscopic objects based on density maps
        - 🎬 **Tracking**: Tracking microscopic objects in video sequences
        """
    )

    # Global state shared by all tabs.
    current_query_id = gr.State(str(uuid.uuid4()))
    user_uploaded_examples = gr.State(example_images_seg.copy())  # seg-tab gallery contents

    def submit_user_feedback(query_id, score, comment, annot_val):
        """Persist a rating/comment (plus the annotated input, if any) to HF storage.

        Shared by all three tabs; the original file defined three identical
        copies of this handler.
        """
        try:
            img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
            bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else []

            save_feedback_to_hf(
                query_id=query_id,
                feedback_type=f"score_{int(score)}",
                feedback_text=comment,
                img_path=img_path,
                bboxes=bboxes
            )
            return "✅ Feedback submitted, thank you!", gr.update(visible=True)
        except Exception as e:
            return f"❌ Submission failed: {str(e)}", gr.update(visible=True)

    with gr.Tabs():
        # ===== Tab 1: Segmentation =====
        with gr.Tab("🎨 Segmentation"):
            gr.Markdown("## Instance Segmentation of Microscopic Objects")
            gr.Markdown(
                """
                **Instructions:**
                1. Upload an image or select an example image (supports various formats: .png, .jpg, .tif)
                2. (Optional) Specify a target object with a bounding box and select "Yes", or click "Run Segmentation" directly
                3. Click "Run Segmentation"
                4. View the segmentation results, download the original predicted mask (.tif format); if needed, click "Clear Selection" to choose a new image

                🤘 Rate and submit feedback to help us improve the model!
                """
            )

            with gr.Row():
                with gr.Column(scale=1):
                    annotator = BBoxAnnotator(
                        label="🖼️ Upload Image (Optional: Provide a Bounding Box)",
                        categories=["cell"],
                    )

                    # Example images gallery.
                    example_gallery = gr.Gallery(
                        label="📁 Example Image Gallery",
                        columns=len(example_images_seg),
                        rows=1,
                        height=120,
                        object_fit="cover",
                        show_download_button=False
                    )

                    with gr.Row():
                        use_box_radio = gr.Radio(
                            choices=["Yes", "No"],
                            value="No",
                            label="🔲 Specify Bounding Box?"
                        )
                    with gr.Row():
                        run_seg_btn = gr.Button("▶️ Run Segmentation", variant="primary", size="lg")
                        seg_clear_btn = gr.Button("🔄 Clear Selection", variant="secondary")

                    # Lets users contribute new example images to the gallery.
                    image_uploader = gr.Image(
                        label="➕ Upload New Example Image to Gallery",
                        type="filepath"
                    )

                with gr.Column(scale=2):
                    seg_output = gr.Image(
                        type="pil",
                        label="📸 Segmentation Result",
                        elem_classes="uniform-height"
                    )
                    download_mask_btn = gr.File(
                        label="📥 Download Original Prediction (.tif format)",
                        visible=True,
                        height=40,
                    )
                    seg_score_slider = gr.Slider(
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=5,
                        label="🌟 Satisfaction Rating (1-5)"
                    )
                    seg_feedback_box = gr.Textbox(
                        placeholder="Please enter your feedback...",
                        lines=2,
                        label="💬 Feedback"
                    )
                    seg_submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary")
                    seg_feedback_status = gr.Textbox(
                        label="✅ Submission Status",
                        lines=1,
                        visible=False
                    )

            # Run segmentation.
            run_seg_btn.click(
                fn=segment_with_choice,
                inputs=[use_box_radio, annotator],
                outputs=[seg_output, download_mask_btn]
            )

            # Clear the annotator.
            seg_clear_btn.click(fn=lambda: None, inputs=None, outputs=annotator)

            # Populate the gallery on page load.
            demo.load(
                fn=lambda: example_images_seg.copy(),
                outputs=example_gallery
            )

            def add_to_gallery(img_path, current_imgs):
                """Append an uploaded image path to the example list (no duplicates)."""
                if not img_path:
                    return current_imgs
                try:
                    if img_path not in current_imgs:
                        current_imgs.append(img_path)
                    return current_imgs
                except Exception:
                    return current_imgs

            image_uploader.change(
                fn=add_to_gallery,
                inputs=[image_uploader, user_uploaded_examples],
                outputs=user_uploaded_examples
            ).then(
                fn=lambda imgs: imgs,
                inputs=user_uploaded_examples,
                outputs=example_gallery
            )

            def load_from_gallery(evt: gr.SelectData, all_imgs):
                """Load the clicked gallery image into the annotator."""
                if evt.index is not None and evt.index < len(all_imgs):
                    return all_imgs[evt.index]
                return None

            example_gallery.select(
                fn=load_from_gallery,
                inputs=user_uploaded_examples,
                outputs=annotator
            )

            seg_submit_feedback_btn.click(
                fn=submit_user_feedback,
                inputs=[current_query_id, seg_score_slider, seg_feedback_box, annotator],
                outputs=[seg_feedback_status, seg_feedback_status]
            )

        # ===== Tab 2: Counting =====
        with gr.Tab("🔢 Counting"):
            gr.Markdown("## Microscopy Object Counting Analysis")
            gr.Markdown(
                """
                **Usage Instructions:**
                1. Upload an image or select an example image (supports multiple formats: .png, .jpg, .tif)
                2. (Optional) Specify a target object with a bounding box and select "Yes", or click "Run Counting" directly
                3. Click "Run Counting"
                4. View the density map, download the original prediction (.npy format); if needed, click "Clear Selection" to choose a new image to run

                🤘 Rate and submit feedback to help us improve the model!
                """
            )

            with gr.Row():
                with gr.Column(scale=1):
                    count_annotator = BBoxAnnotator(
                        label="🖼️ Upload Image (Optional: Provide a Bounding Box)",
                        categories=["cell"],
                    )

                    with gr.Row():
                        count_example_gallery = gr.Gallery(
                            label="📁 Example Image Gallery",
                            columns=len(example_images_cnt),
                            rows=1,
                            object_fit="cover",
                            height=120,
                            value=example_images_cnt.copy(),  # initialize with examples
                            show_download_button=False
                        )

                    with gr.Row():
                        count_use_box_radio = gr.Radio(
                            choices=["Yes", "No"],
                            value="No",
                            label="🔲 Specify Bounding Box?"
                        )

                    with gr.Row():
                        count_btn = gr.Button("▶️ Run Counting", variant="primary", size="lg")
                        count_clear_btn = gr.Button("🔄 Clear Selection", variant="secondary")

                    with gr.Row():
                        count_image_uploader = gr.File(
                            label="➕ Add Example Image to Gallery",
                            file_types=["image"],
                            type="filepath"
                        )

                with gr.Column(scale=2):
                    count_output = gr.Image(
                        label="📸 Density Map",
                        type="filepath",
                        elem_id="density_map_output"
                    )
                    count_status = gr.Textbox(
                        label="📊 Statistics",
                        lines=2
                    )
                    download_density_btn = gr.File(
                        label="📥 Download Original Prediction (.npy format)",
                        visible=True
                    )
                    count_score_slider = gr.Slider(
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=5,
                        label="🌟 Satisfaction Rating (1-5)"
                    )
                    count_feedback_box = gr.Textbox(
                        placeholder="Please enter your feedback...",
                        lines=2,
                        label="💬 Feedback"
                    )
                    count_submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary")
                    count_feedback_status = gr.Textbox(
                        label="✅ Submission Status",
                        lines=1,
                        visible=False
                    )

            # State backing the counting gallery.
            count_user_examples = gr.State(example_images_cnt.copy())

            def add_to_count_gallery(new_img_file, current_imgs):
                """Add an uploaded file to the counting example gallery."""
                if new_img_file is None:
                    return current_imgs, current_imgs
                try:
                    if new_img_file not in current_imgs:
                        current_imgs.append(new_img_file)
                        print(f"✅ Added image to gallery: {new_img_file}")
                except Exception as e:
                    print(f"⚠️ Failed to add image: {e}")
                return current_imgs, current_imgs

            count_image_uploader.upload(
                fn=add_to_count_gallery,
                inputs=[count_image_uploader, count_user_examples],
                outputs=[count_user_examples, count_example_gallery]
            )

            def load_from_count_gallery(evt: gr.SelectData, all_imgs):
                """Load the clicked gallery image into the counting annotator."""
                if evt.index is not None and evt.index < len(all_imgs):
                    selected_img = all_imgs[evt.index]
                    print(f"📸 Loading image from gallery: {selected_img}")
                    return selected_img
                return None

            count_example_gallery.select(
                fn=load_from_count_gallery,
                inputs=count_user_examples,
                outputs=count_annotator
            )

            # Run counting.
            count_btn.click(
                fn=count_cells_handler,
                inputs=[count_use_box_radio, count_annotator],
                outputs=[count_output, download_density_btn, count_status]
            )

            count_clear_btn.click(fn=lambda: None, inputs=None, outputs=count_annotator)

            # BUG FIX: the original bound this click to the *segmentation* tab's
            # `annotator`, so counting feedback captured the wrong image/boxes.
            count_submit_feedback_btn.click(
                fn=submit_user_feedback,
                inputs=[current_query_id, count_score_slider, count_feedback_box, count_annotator],
                outputs=[count_feedback_status, count_feedback_status]
            )

        # ===== Tab 3: Tracking =====
        with gr.Tab("🎬 Tracking"):
            gr.Markdown("## Microscopy Object Video Tracking - Supports ZIP Upload")
            gr.Markdown(
                """
                **Instructions:**
                1. Upload a ZIP file or select from the example library. The ZIP should contain a sequence of TIF images named in chronological order (e.g., t000.tif, t001.tif...)
                2. (Optional) Specify a target object with a bounding box on the first frame and select "Yes", or click "Run Tracking" directly
                3. Click "Run Tracking"
                4. Download the CTC format results; if needed, click "Clear Selection" to choose a new ZIP file to run

                🤘 Rate and submit feedback to help us improve the model!

                """
            )

            with gr.Row():
                with gr.Column(scale=1):
                    track_zip_upload = gr.File(
                        label="📦 Upload Image Sequence in ZIP File",
                        file_types=[".zip"]
                    )

                    # First-frame annotation; revealed once a ZIP is uploaded.
                    track_first_frame_annotator = BBoxAnnotator(
                        label="🖼️ (Optional) First Frame Bounding Box Annotation",
                        categories=["cell"],
                        visible=False,
                    )

                    track_example_gallery = gr.Gallery(
                        label="📁 Example Video Gallery (Click to Select)",
                        columns=10,
                        rows=1,
                        height=120,
                        object_fit="contain",
                        show_download_button=False
                    )

                    with gr.Row():
                        track_use_box_radio = gr.Radio(
                            choices=["Yes", "No"],
                            value="No",
                            label="🔲 Specify Bounding Box?"
                        )

                    with gr.Row():
                        track_btn = gr.Button("▶️ Run Tracking", variant="primary", size="lg")
                        track_clear_btn = gr.Button("🔄 Clear Selection", variant="secondary")

                    track_gallery_upload = gr.File(
                        label="➕ Add ZIP to Example Gallery",
                        file_types=[".zip"],
                        type="filepath"
                    )

                with gr.Column(scale=2):
                    track_first_frame_preview = gr.Image(
                        label="📸 Tracking Visualization",
                        type="filepath",
                        elem_classes="uniform-height",
                        interactive=False
                    )
                    track_output = gr.Textbox(
                        label="📊 Tracking Information",
                        lines=8,
                        interactive=False
                    )
                    track_download = gr.File(
                        label="📥 Download Tracking Results (CTC Format)",
                        visible=False
                    )
                    track_score_slider = gr.Slider(
                        minimum=1,
                        maximum=5,
                        step=1,
                        value=5,
                        label="🌟 Satisfaction Rating (1-5)"
                    )
                    track_feedback_box = gr.Textbox(
                        placeholder="Please enter your feedback...",
                        lines=2,
                        label="💬 Feedback"
                    )
                    track_submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary")
                    track_feedback_status = gr.Textbox(
                        label="✅ Submission Status",
                        lines=1,
                        visible=False
                    )

            # State backing the tracking-example gallery (paths to ZIPs).
            track_user_examples = gr.State(example_tracking_zips.copy())

            def get_zip_preview(zip_path):
                """Extract the first image in a ZIP and save a normalized PNG preview."""
                try:
                    temp_dir = tempfile.mkdtemp()
                    with zipfile.ZipFile(zip_path, 'r') as zip_ref:
                        for member in zip_ref.namelist():
                            basename = os.path.basename(member)
                            if ('__MACOSX' not in member and
                                    not basename.startswith('._') and
                                    basename.lower().endswith(('.tif', '.tiff', '.png', '.jpg'))):
                                zip_ref.extract(member, temp_dir)
                                extracted_path = os.path.join(temp_dir, member)

                                import tifffile
                                import numpy as np

                                img_np = tifffile.imread(extracted_path)
                                # Stretch uint16 to uint8 for display.
                                if img_np.dtype == np.uint16:
                                    img_min, img_max = img_np.min(), img_np.max()
                                    if img_max > img_min:
                                        img_np = ((img_np.astype(np.float32) - img_min) / (img_max - img_min) * 255).astype(np.uint8)

                                if img_np.ndim == 2:
                                    img_np = np.stack([img_np] * 3, axis=-1)

                                preview_path = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
                                Image.fromarray(img_np).save(preview_path.name)
                                return preview_path.name
                except Exception:
                    pass
                return None

            def init_tracking_gallery():
                """Create preview thumbnails for the bundled example ZIPs."""
                previews = []
                for zip_path in example_tracking_zips:
                    if os.path.exists(zip_path):
                        preview = get_zip_preview(zip_path)
                        if preview:
                            previews.append(preview)
                return previews

            # Populate the gallery on page load.
            demo.load(
                fn=init_tracking_gallery,
                outputs=track_example_gallery
            )

            def add_zip_to_gallery(zip_path, current_zips):
                """Add an uploaded ZIP to the example list and refresh previews."""
                # BUG FIX: the original early-exit returned the Gallery *component
                # object* as an output value; return a no-op update instead.
                if not zip_path:
                    return current_zips, gr.update()
                try:
                    if zip_path not in current_zips:
                        current_zips.append(zip_path)
                        print(f"✅ Added ZIP to gallery: {zip_path}")
                    # Regenerate previews for every known ZIP.
                    previews = []
                    for zp in current_zips:
                        preview = get_zip_preview(zp)
                        if preview:
                            previews.append(preview)
                    return current_zips, previews
                except Exception as e:
                    print(f"⚠️ Error: {e}")
                    return current_zips, []

            track_gallery_upload.upload(
                fn=add_zip_to_gallery,
                inputs=[track_gallery_upload, track_user_examples],
                outputs=[track_user_examples, track_example_gallery]
            )

            def load_zip_from_gallery(evt: gr.SelectData, all_zips):
                """Map a gallery click back to its source ZIP path."""
                if evt.index is not None and evt.index < len(all_zips):
                    selected_zip = all_zips[evt.index]
                    print(f"📁 Selected ZIP from gallery: {selected_zip}")
                    return selected_zip
                return None

            track_example_gallery.select(
                fn=load_zip_from_gallery,
                inputs=track_user_examples,
                outputs=track_zip_upload
            )

            def load_first_frame_for_annotation(zip_file_obj):
                """Load and normalize the first frame from the ZIP for annotation."""
                if zip_file_obj is None:
                    return None, gr.update(visible=False)

                import tifffile
                import numpy as np

                try:
                    temp_dir = tempfile.mkdtemp()
                    with zipfile.ZipFile(zip_file_obj.name, 'r') as zip_ref:
                        for member in zip_ref.namelist():
                            basename = os.path.basename(member)
                            if ('__MACOSX' not in member and
                                    not basename.startswith('._') and
                                    basename.lower().endswith(('.tif', '.tiff'))):
                                zip_ref.extract(member, temp_dir)

                    tif_dir = find_valid_tif_dir(temp_dir)
                    if tif_dir:
                        first_frame = extract_first_frame(tif_dir)
                        if first_frame:
                            try:
                                img_np = tifffile.imread(first_frame)

                                # Normalize to [0, 255] uint8 for display.
                                if img_np.dtype == np.uint8:
                                    pass  # already displayable
                                elif img_np.dtype == np.uint16:
                                    img_min, img_max = img_np.min(), img_np.max()
                                    if img_max > img_min:
                                        img_np = ((img_np.astype(np.float32) - img_min) / (img_max - img_min) * 255).astype(np.uint8)
                                    else:
                                        img_np = (img_np.astype(np.float32) / 65535.0 * 255).astype(np.uint8)
                                else:
                                    # Float or other types: min-max stretch.
                                    img_np = img_np.astype(np.float32)
                                    img_min, img_max = img_np.min(), img_np.max()
                                    if img_max > img_min:
                                        img_np = ((img_np - img_min) / (img_max - img_min) * 255).astype(np.uint8)
                                    else:
                                        img_np = np.clip(img_np * 255, 0, 255).astype(np.uint8)

                                # Convert to RGB; drop extra channels.
                                if img_np.ndim == 2:
                                    img_np = np.stack([img_np] * 3, axis=-1)
                                elif img_np.ndim == 3 and img_np.shape[2] > 3:
                                    img_np = img_np[:, :, :3]

                                temp_img = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
                                Image.fromarray(img_np).save(temp_img.name)

                                print(f"✅ Loaded and normalized first frame: {first_frame}")
                                return temp_img.name, gr.update(visible=True)
                            except Exception as e:
                                print(f"⚠️ Error normalizing first frame: {e}")
                                import traceback
                                traceback.print_exc()
                                # Fall back to the original file path.
                                return first_frame, gr.update(visible=True)
                except Exception as e:
                    print(f"⚠️ Error loading first frame: {e}")
                    import traceback
                    traceback.print_exc()
                # BUG FIX: the original fell off the `try` without a return when
                # no TIF dir / first frame was found, yielding a single None
                # instead of the 2-tuple Gradio expects here.
                return None, gr.update(visible=False)

            # NOTE(review): the annotator appears twice in `outputs` (value +
            # visibility update), mirroring the original wiring — verify this is
            # accepted by the installed Gradio version.
            track_zip_upload.change(
                fn=load_first_frame_for_annotation,
                inputs=track_zip_upload,
                outputs=[track_first_frame_annotator, track_first_frame_annotator]
            )

            # Run tracking.
            track_btn.click(
                fn=track_video_handler,
                inputs=[track_use_box_radio, track_first_frame_annotator, track_zip_upload],
                outputs=[track_download, track_output, track_download, track_first_frame_preview]
            )

            track_clear_btn.click(fn=lambda: None, inputs=None, outputs=track_first_frame_annotator)

            # BUG FIX: the original bound this click to the *segmentation* tab's
            # `annotator`, so tracking feedback captured the wrong image/boxes.
            track_submit_feedback_btn.click(
                fn=submit_user_feedback,
                inputs=[current_query_id, track_score_slider, track_feedback_box, track_first_frame_annotator],
                outputs=[track_feedback_status, track_feedback_status]
            )

    gr.Markdown(
        """
        ---
        ### 💡 Technical Details

        **MicroscopyMatching** - A general-purpose microscopy image analysis toolkit based on Stable Diffusion
        """
    )
1631
if __name__ == "__main__":
    # Queue requests and serve on all interfaces at port 7860.
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
        "ssr_mode": False,
        "show_error": True,
    }
    demo.queue().launch(**launch_options)
config.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass, field
2
+ from pathlib import Path
3
+ from typing import Dict, List
4
+
5
+
6
@dataclass
class RunConfig:
    """Configuration for a Stable Diffusion / attend-and-excite generation run.

    Instantiating this dataclass creates ``output_path`` on disk as a side
    effect (see ``__post_init__``).
    """
    # Guiding text prompt
    prompt: str = "<task-prompt>"
    # Whether to use Stable Diffusion v2.1
    sd_2_1: bool = False
    # Which token indices to alter with attend-and-excite
    token_indices: List[int] = field(default_factory=lambda: [2,5])
    # Which random seeds to use when generating
    seeds: List[int] = field(default_factory=lambda: [42])
    # Path to save all outputs to
    output_path: Path = Path('./outputs')
    # Number of denoising steps
    n_inference_steps: int = 50
    # Text guidance scale
    guidance_scale: float = 7.5
    # Number of denoising steps to apply attend-and-excite
    max_iter_to_alter: int = 25
    # Resolution of UNet to compute attention maps over
    attention_res: int = 16
    # Whether to run standard SD or attend-and-excite
    run_standard_sd: bool = False
    # Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in
    thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8})
    # Scale factor for updating the denoised latent z_t
    scale_factor: int = 20
    # Start and end values used for scaling the scale factor - decays linearly with the denoising timestep
    # (pair of floats: (start, end))
    scale_range: tuple = field(default_factory=lambda: (1.0, 0.5))
    # Whether to apply the Gaussian smoothing before computing the maximum attention value for each subject token
    smooth_attentions: bool = True
    # Standard deviation for the Gaussian smoothing
    sigma: float = 0.5
    # Kernel size for the Gaussian smoothing
    kernel_size: int = 3
    # Whether to save cross attention maps for the final results
    save_cross_attention_maps: bool = False

    def __post_init__(self):
        # Ensure the output directory exists before any results are written.
        self.output_path.mkdir(exist_ok=True, parents=True)
counting.py ADDED
@@ -0,0 +1,340 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # stable diffusion x loca
2
+ import os
3
+ import pprint
4
+ from typing import Any, List, Optional
5
+ import argparse
6
+ from huggingface_hub import hf_hub_download
7
+ import pyrallis
8
+ from pytorch_lightning.utilities.types import STEP_OUTPUT
9
+ import torch
10
+ import os
11
+ from PIL import Image
12
+ import numpy as np
13
+ from config import RunConfig
14
+ from _utils import attn_utils_new as attn_utils
15
+ from _utils.attn_utils import AttentionStore
16
+ from _utils.misc_helper import *
17
+ import torch.nn.functional as F
18
+ import matplotlib.pyplot as plt
19
+ import cv2
20
+ import warnings
21
+ from pytorch_lightning.callbacks import ModelCheckpoint
22
+ warnings.filterwarnings("ignore", category=UserWarning)
23
+ import pytorch_lightning as pl
24
+ from _utils.load_models import load_stable_diffusion_model
25
+ from models.model import Counting_with_SD_features_loca as Counting
26
+ from pytorch_lightning.loggers import WandbLogger
27
+ from models.enc_model.loca_args import get_argparser as loca_get_argparser
28
+ from models.enc_model.loca import build_model as build_loca_model
29
+ import time
30
+ import torchvision.transforms as T
31
+ import skimage.io as io
32
+
33
+ SCALE = 1
34
+
35
+
36
class CountingModule(pl.LightningModule):
    """Counting pipeline that fuses a LOCA counting backbone with Stable
    Diffusion attention maps.

    An input image is encoded into SD latents, lightly noised, and denoised
    once while an :class:`AttentionStore` captures self/cross-attention.
    The aggregated attention maps are stacked and fed into LOCA's regression
    head to predict a density map whose sum is the object count.
    """

    def __init__(self, use_box=True):
        # use_box: when True, exemplar boxes guide counting (few-shot mode);
        # when False, a text-only (zero-shot) path with dummy boxes is used.
        super().__init__()
        self.use_box = use_box
        self.config = RunConfig()  # config for stable diffusion
        self.initialize_model()

    def initialize_model(self):
        """Build LOCA, the counting adapter, Stable Diffusion (with attention
        hooks), and register the learned ``<task-prompt>`` placeholder token."""

        # load loca model
        loca_args = loca_get_argparser().parse_args()
        self.loca_model = build_loca_model(loca_args)
        # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
        # weights = {k.replace("module","") : v for k, v in weights.items()}
        # self.loca_model.load_state_dict(weights, strict=False)
        # del weights

        self.counting_adapter = Counting(scale_factor=SCALE)
        # if os.path.isfile(self.args.adapter_weight):
        #     adapter_weight = torch.load(self.args.adapter_weight,map_location=torch.device('cpu'))
        #     self.counting_adapter.load_state_dict(adapter_weight, strict=False)

        ### load stable diffusion and its controller
        self.stable = load_stable_diffusion_model(config=self.config)
        self.noise_scheduler = self.stable.scheduler
        # AttentionStore records attention at resolutions up to 64x64.
        self.controller = AttentionStore(max_size=64)
        attn_utils.register_attention_control(self.stable, self.controller)
        attn_utils.register_hier_output(self.stable)

        ##### initialize token_emb #####
        placeholder_token = "<task-prompt>"
        self.task_token = "repetitive objects"
        # Add the placeholder token in tokenizer
        num_added_tokens = self.stable.tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        try:
            # Preferred path: initialize the placeholder embedding from a
            # pretrained task embedding hosted on the Hub.
            task_embed_from_pretrain = hf_hub_download(
                repo_id="phoebe777777/111",
                filename="task_embed.pth",
                token=None,
                force_download=False
            )
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)
            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))

            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            # NOTE(review): hf_hub_download returns a *file path* (str);
            # assigning it into the embedding tensor will raise and fall
            # through to the except branch below. A `torch.load` on the
            # downloaded file appears to be missing here — confirm.
            token_embeds[placeholder_token_id] = task_embed_from_pretrain
        except:
            # Fallback: seed the placeholder with the embedding of "count".
            initializer_token = "count"
            token_ids = self.stable.tokenizer.encode(initializer_token, add_special_tokens=False)
            # Check if initializer_token is a single token or a sequence of tokens
            if len(token_ids) > 1:
                raise ValueError("The initializer token must be a single token.")

            initializer_token_id = token_ids[0]
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)

            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))

            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

        # others
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id

    def move_to_device(self, device):
        """Move Stable Diffusion, LOCA, the adapter, and this module to `device`."""
        self.stable.to(device)
        if self.loca_model is not None and self.counting_adapter is not None:
            self.loca_model.to(device)
            self.counting_adapter.to(device)
        self.to(device)

    def forward(self, data_path, box=None):
        """Predict a density map and object count for one image.

        Args:
            data_path: path to the input image file.
            box: optional exemplar boxes ``[[x1, y1, x2, y2], ...]`` in
                original image pixel coordinates. Must be given iff the
                module was built with ``use_box=True``.

        Returns:
            (pred_density_rsz, pred_cnt): density map resized back to the
            original image size (renormalized so it sums to the count), and
            the predicted count (a Python float).
        """
        filename = data_path.split("/")[-1]  # NOTE(review): unused below
        img = Image.open(data_path).convert("RGB")
        width, height = img.size
        # Two views of the image: one shifted to roughly [-0.5, 0.5] for the
        # SD VAE, one ImageNet-normalized for the LOCA backbone.
        input_image = T.Compose([T.ToTensor(), T.Resize((512, 512))])(img)
        input_image_stable = input_image - 0.5
        input_image = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_image)
        if box is not None:
            # Rescale boxes from original pixel coords to the 512x512 frame.
            boxes = torch.tensor(box) / torch.tensor([width, height, width, height]) * 512  # xyxy, normalized
            assert self.use_box == True
        else:
            boxes = torch.tensor([[100, 100, 130, 130], [200, 200, 250, 250]], dtype=torch.float32)  # dummy box
            assert self.use_box == False

        # move to device
        input_image = input_image.unsqueeze(0).to(self.device)
        boxes = boxes.unsqueeze(0).to(self.device)
        input_image_stable = input_image_stable.unsqueeze(0).to(self.device)

        # Encode to SD latent space; 0.18215 is the SD latent scaling factor.
        latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
        latents = latents * 0.18215
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        # Single fixed, small timestep: only a light noising is applied.
        timesteps = torch.tensor([20], device=latents.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        input_ids_ = self.stable.tokenizer(
            self.placeholder_token + " repetitive objects",
            # "object",
            padding="max_length",
            truncation=True,
            max_length=self.stable.tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = input_ids_["input_ids"].to(self.device)
        attention_mask = input_ids_["attention_mask"].to(self.device)
        encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]

        input_image = input_image.to(self.device)
        boxes = boxes.to(self.device)

        # Locate the placeholder token inside the tokenized prompt.
        task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
        if self.use_box:
            # Inject an exemplar-conditioned embedding produced by the
            # counting adapter right after the task-prompt token.
            loca_out = self.loca_model.forward_before_reg(input_image, boxes)
            loca_feature_bf_regression = loca_out["feature_bf_regression"]
            adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes)  # shape [1, 768]
            if task_loc_idx.shape[0] == 0:
                encoder_hidden_states[0, 2, :] = adapted_emb.squeeze()  # place right after the task-prompt token
            else:
                encoder_hidden_states[0, task_loc_idx[0, 1] + 1, :] = adapted_emb.squeeze()  # place right after the task-prompt token

        # Predict the noise residual
        noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
        noise_pred = noise_pred.sample
        attention_store = self.controller.attention_store

        attention_maps = []
        exemplar_attention_maps = []
        exemplar_attention_maps1 = []
        exemplar_attention_maps2 = []
        exemplar_attention_maps3 = []

        cross_self_task_attn_maps = []
        cross_self_exe_attn_maps = []

        # only use 64x64 self-attention
        self_attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt],  # TODO: should this prompt be changed?
            attention_store=self.controller,
            res=64,
            from_where=("up", "down"),
            is_cross=False,
            select=0
        )
        # NOTE(review): the 32x32 and 16x16 self-attention aggregates below
        # are computed but never used in this method — confirm intent.
        self_attn_aggregate32 = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt],  # TODO: should this prompt be changed?
            attention_store=self.controller,
            res=32,
            from_where=("up", "down"),
            is_cross=False,
            select=0
        )
        self_attn_aggregate16 = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt],  # TODO: should this prompt be changed?
            attention_store=self.controller,
            res=16,
            from_where=("up", "down"),
            is_cross=False,
            select=0
        )

        # cross attention
        for res in [32, 16]:
            attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 77]
                prompts=[self.config.prompt],  # TODO: should this prompt be changed?
                attention_store=self.controller,
                res=res,
                from_where=("up", "down"),
                is_cross=True,
                select=0
            )

            # Token 1 is the task-prompt token; tokens 2..4 follow it.
            task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0)  # [1, 1, res, res]
            attention_maps.append(task_attn_)
            if self.use_box:
                exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)  # take the exemplar's attention
                exemplar_attention_maps.append(exemplar_attns)
            else:
                exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)
                exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0)
                exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0)
                exemplar_attention_maps1.append(exemplar_attns1)
                exemplar_attention_maps2.append(exemplar_attns2)
                exemplar_attention_maps3.append(exemplar_attns3)

        # Upsample each cross-attention map to 64x64 and average resolutions.
        scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
        attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))])
        task_attn_64 = torch.mean(attns, dim=0, keepdim=True)
        # Propagate the cross-attention map through self-attention affinities.
        cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
        cross_self_task_attn_maps.append(cross_self_task_attn)

        if self.use_box:
            scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))])
            exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True)

            cross_self_exe_attn = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64)
            cross_self_exe_attn_maps.append(cross_self_exe_attn)
        else:
            # Zero-shot path: average three exemplar-token attention maps.
            scale_factors = [(64 // exemplar_attention_maps1[i].shape[-1]) for i in range(len(exemplar_attention_maps1))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps1[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps1))])
            exemplar_attn_64_1 = torch.mean(attns, dim=0, keepdim=True)

            scale_factors = [(64 // exemplar_attention_maps2[i].shape[-1]) for i in range(len(exemplar_attention_maps2))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps2[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps2))])
            exemplar_attn_64_2 = torch.mean(attns, dim=0, keepdim=True)

            scale_factors = [(64 // exemplar_attention_maps3[i].shape[-1]) for i in range(len(exemplar_attention_maps3))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps3[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps3))])
            exemplar_attn_64_3 = torch.mean(attns, dim=0, keepdim=True)

            cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
            cross_self_task_attn_maps.append(cross_self_task_attn)

            # if self.args.merge_exemplar == "average":
            cross_self_exe_attn1 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_1)
            cross_self_exe_attn2 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_2)
            cross_self_exe_attn3 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_3)
            exemplar_attn_64 = (exemplar_attn_64_1 + exemplar_attn_64_2 + exemplar_attn_64_3) / 3
            cross_self_exe_attn = (cross_self_exe_attn1 + cross_self_exe_attn2 + cross_self_exe_attn3) / 3

        # Min-max normalize the exemplar map before stacking.
        exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6)

        # 4-channel stack fed into LOCA's regression head.
        attn_stack = [exemplar_attn_64 / 2, cross_self_exe_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        attn_stack = torch.cat(attn_stack, dim=1)

        if not self.use_box:

            # cross_self_exe_attn_np = cross_self_exe_attn.detach().squeeze().cpu().numpy()
            # boxes = gen_dummy_boxes(cross_self_exe_attn_np, max_boxes=1)
            # boxes = boxes.to(self.device)

            # In zero-shot mode the backbone pass happens here (with dummy
            # boxes); in few-shot mode loca_out was computed earlier.
            loca_out = self.loca_model.forward_before_reg(input_image, boxes)
            loca_feature_bf_regression = loca_out["feature_bf_regression"]
        attn_out = self.loca_model.forward_reg(loca_out, attn_stack, feature_list[-1])
        pred_density = attn_out["pred"].squeeze().cpu().numpy()
        pred_cnt = pred_density.sum().item()

        # resize pred_density to original image size
        pred_density_rsz = cv2.resize(pred_density, (width, height), interpolation=cv2.INTER_CUBIC)
        # Renormalize so the resized map still sums to the predicted count.
        pred_density_rsz = pred_density_rsz / pred_density_rsz.sum() * pred_cnt

        return pred_density_rsz, pred_cnt
293
+
294
+
295
def inference(data_path, box=None, save_path="./example_imgs", visualize=False):
    """Run the counting model on a single image.

    Args:
        data_path: path to the input image.
        box: optional exemplar boxes ``[[x1, y1, x2, y2], ...]`` in image
            pixel coordinates; when given the box-guided pipeline is used.
        save_path: directory where the optional visualization is written.
        visualize: when True, save a side-by-side overlay of the predicted
            density map next to the input image.

    Returns:
        The predicted density map, resized to the input image resolution.
    """
    use_box = box is not None
    model = CountingModule(use_box=use_box)
    # strict=True: the checkpoint must match the module's state dict exactly.
    model.load_state_dict(torch.load("pretrained/microscopy_matching_cnt.pth"), strict=True)
    model.eval()
    with torch.no_grad():
        density_map, cnt = model(data_path, box)

    if visualize:
        img = io.imread(data_path)
        # Normalize channel layout: drop alpha, expand grayscale to RGB.
        if len(img.shape) == 3 and img.shape[2] > 3:
            img = img[:, :, :3]
        if len(img.shape) == 2:
            img = np.stack([img] * 3, axis=-1)
        img_show = img.squeeze()
        density_map_show = density_map.squeeze()
        os.makedirs(save_path, exist_ok=True)
        filename = data_path.split("/")[-1]
        # Epsilon guards against division by zero on constant images
        # (consistent with the normalization in inference_count.visualize_result).
        img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)
        fig, ax = plt.subplots(1, 2, figsize=(12, 6))
        ax[0].imshow(img_show)
        ax[0].axis('off')
        ax[0].set_title(f"Input image")
        ax[1].imshow(img_show)
        ax[1].imshow(density_map_show, cmap='jet', alpha=0.5)  # Overlay density map with some transparency
        ax[1].axis('off')
        ax[1].set_title(f"Predicted density map, count: {cnt:.1f}")
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, filename.split(".")[0] + "_cnt.png"), dpi=300)
        plt.close()
    return density_map
329
+
330
def main():
    """Demo entry point: count objects in a sample image and save the overlay."""
    inference(
        data_path="example_imgs/1977_Well_F-5_Field_1.png",
        save_path="./example_imgs",
        visualize=True,
    )


if __name__ == "__main__":
    main()
example_imgs/cnt/047cell.png ADDED

Git LFS Details

  • SHA256: 3c9fc3d2ab7beecb16d850b1ef82d70a7f7011051d0199f866bc31c42c296d42
  • Pointer size: 130 Bytes
  • Size of remote file: 72.8 kB
example_imgs/cnt/62_10.png ADDED

Git LFS Details

  • SHA256: b93c916a81eaec1a3511b9379fa293c026bbe74977bc21fc7666a83c92d3b122
  • Pointer size: 130 Bytes
  • Size of remote file: 91.5 kB
example_imgs/cnt/6800-17000_GTEX-XQ3S_Adipose-Subcutaneous.png ADDED

Git LFS Details

  • SHA256: 467319789c5b5b6c370a23c126c33044a841a115cf24b79f75106b5521cd5c44
  • Pointer size: 130 Bytes
  • Size of remote file: 79.9 kB
example_imgs/seg/003_img.png ADDED

Git LFS Details

  • SHA256: 41515cf5d7405135db4656c2cc61b59ab341143bfbee952b44a9542944e8528f
  • Pointer size: 131 Bytes
  • Size of remote file: 302 kB
example_imgs/seg/1-23 [Scan I08].png ADDED

Git LFS Details

  • SHA256: a96dfccdd794a95c9907b0eedecbd53dee078943d9a3dcdb43e11a36d34f5a1f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.42 MB
example_imgs/seg/10X_B2_Tile-15.aligned.png ADDED

Git LFS Details

  • SHA256: e8dce16565ccfb055438b0b65d9e70b5be6cc36c61a964eed53d7ec782b5afa3
  • Pointer size: 132 Bytes
  • Size of remote file: 1.52 MB
example_imgs/seg/1977_Well_F-5_Field_1.png ADDED

Git LFS Details

  • SHA256: 145a99e724048ed40db7843e57a1d93cd2e1f6e221d167a29b732740d6302c52
  • Pointer size: 132 Bytes
  • Size of remote file: 2.43 MB
example_imgs/seg/200972823[5179]_RhoGGG_YAP_TAZ [200972823 Well K6 Field #2].png ADDED

Git LFS Details

  • SHA256: 56bd7a8df07d66ff5f8dac67aa116efe0869f6c46d9ce77e595535a6acd60ae9
  • Pointer size: 132 Bytes
  • Size of remote file: 1.39 MB
example_imgs/seg/A172_Phase_C7_1_00d00h00m_1.png ADDED

Git LFS Details

  • SHA256: f57430b87923f5de9a5799cc84016aeb5d99cd5068481a9fedae2a68fa9bba43
  • Pointer size: 131 Bytes
  • Size of remote file: 159 kB
example_imgs/seg/JE2NileRed_oilp22_PMP_101220_011_NR.png ADDED

Git LFS Details

  • SHA256: bdf31a4eab7826435407f2f88bfeee8f95c2b04d8f579cf6281b7f5838195b03
  • Pointer size: 130 Bytes
  • Size of remote file: 64.4 kB
example_imgs/seg/OpenTest_031.png ADDED

Git LFS Details

  • SHA256: 973ecd4ca18c650d630491c1f3531ba4ff20c12a37728dc79f279b26651d0c82
  • Pointer size: 131 Bytes
  • Size of remote file: 966 kB
example_imgs/seg/X_24.png ADDED

Git LFS Details

  • SHA256: 514b2df4bdcdd1d09d1f032284a5c2aaa0572d2f1ec148b256e4bbf5d68eb3c7
  • Pointer size: 131 Bytes
  • Size of remote file: 102 kB
example_imgs/seg/exp_A01_G002_0001.oir.png ADDED

Git LFS Details

  • SHA256: 9c22531659320908a688da277b7f67b70aafb450e035f56e3962ebfd3423140f
  • Pointer size: 132 Bytes
  • Size of remote file: 1.69 MB
example_imgs/tra/tracking_test_sequence.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:bda69434e3de8103c98313777640acd35fc7501eec4b1528456304142b18797f
3
+ size 10392163
example_imgs/tra/tracking_test_sequence2.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:120cc2a75a4dd571b8f8ee7ea363a9b82a2b4c516376ccf4f287b6864d2dd576
3
+ size 2288296
inference_count.py ADDED
@@ -0,0 +1,237 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference_count.py
2
+ # 计数模型推理模块 - 独立版本
3
+
4
+ import torch
5
+ import numpy as np
6
+ from PIL import Image
7
+ import matplotlib.pyplot as plt
8
+ import tempfile
9
+ import os
10
+ from huggingface_hub import hf_hub_download
11
+ from counting import CountingModule
12
+
13
+ MODEL = None
14
+ DEVICE = torch.device("cpu")
15
+
16
def load_model(use_box=False):
    """Instantiate the counting model, fetch its checkpoint from the
    Hugging Face Hub, and place it on the best available device.

    Args:
        use_box: whether the model should consume exemplar bounding boxes.

    Returns:
        (model, device) on success, or (None, cpu device) on any failure.
    """
    global MODEL, DEVICE

    try:
        print("🔄 Loading counting model...")

        MODEL = CountingModule(use_box=use_box)

        ckpt_path = hf_hub_download(
            repo_id="phoebe777777/111",
            filename="microscopy_matching_cnt.pth",
            token=None,
            force_download=False,
        )
        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        state = torch.load(ckpt_path, map_location="cpu")
        MODEL.load_state_dict(state, strict=True)
        MODEL.eval()

        # Prefer CUDA when present; fall back to CPU otherwise.
        on_cuda = torch.cuda.is_available()
        DEVICE = torch.device("cuda" if on_cuda else "cpu")
        MODEL.move_to_device(DEVICE)
        print("✅ Model moved to CUDA" if on_cuda else "✅ Model on CPU")

        print("✅ Counting model loaded successfully")
        return MODEL, DEVICE

    except Exception as e:
        print(f"❌ Error loading counting model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")
69
+
70
+
71
@torch.no_grad()
def run(model, img_path, box=None, device="cpu", visualize=True):
    """Run counting inference on a single image.

    Args:
        model: the counting model, or None if loading failed.
        img_path: path to the input image.
        box: optional exemplar boxes ``[[x1, y1, x2, y2], ...]`` or None.
        device: device to run on.
        visualize: whether to produce a visualization (currently handled
            by the caller — see visualize_result).

    Returns:
        dict with keys 'density_map' (numpy array or None), 'count' (float),
        'visualized_path' (str or None), plus 'error' on failure.
    """
    # Bail out on a missing model FIRST: the previous version called
    # model.move_to_device(...) and set model.use_box before this guard,
    # so a None model raised AttributeError instead of returning the
    # documented error dict.
    if model is None:
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': 'Model not loaded'
        }

    print("DEVICE:", device)
    model.move_to_device(device)
    model.eval()
    # Box-guided mode only when exemplar boxes were actually supplied.
    model.use_box = box is not None

    try:
        print(f"🔄 Running counting inference on {img_path}")

        with torch.no_grad():
            density_map, count = model(img_path, box)

        print(f"✅ Counting result: {count:.1f} objects")

        result = {
            'density_map': density_map,
            'count': count,
            'visualized_path': None
        }

        # Visualization is left to the caller:
        # if visualize:
        #     result['visualized_path'] = visualize_result(img_path, density_map, count)

        return result

    except Exception as e:
        print(f"❌ Counting inference error: {e}")
        import traceback
        traceback.print_exc()
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': str(e)
        }
+ }
139
+
140
+
141
def visualize_result(image_path, density_map, count):
    """Save a jet-colormap overlay of the density map on the input image.

    Args:
        image_path: path to the original image.
        density_map: predicted density map (numpy array).
        count: predicted count (kept for API compatibility; the title that
            used it is currently commented out).

    Returns:
        Path to a temporary PNG containing the overlay, or the original
        image path if anything goes wrong.
    """
    try:
        import skimage.io as io

        img = io.imread(image_path)

        # Normalize channel layout: strip alpha, expand grayscale to RGB.
        if len(img.shape) == 3 and img.shape[2] > 3:
            img = img[:, :, :3]
        if len(img.shape) == 2:
            img = np.stack([img] * 3, axis=-1)

        img_show = img.squeeze()
        density_map_show = density_map.squeeze()

        # Min-max normalize for display; epsilon avoids 0/0 on flat images.
        lo, hi = np.min(img_show), np.max(img_show)
        img_show = (img_show - lo) / (hi - lo + 1e-8)

        fig, ax = plt.subplots(figsize=(8, 6))
        ax.imshow(img_show)
        ax.imshow(density_map_show, cmap='jet', alpha=0.5)
        ax.axis('off')
        # ax.set_title(f"Predicted density map, count: {count:.1f}")

        plt.tight_layout()

        # Write to a named temp file so the caller can serve/display it.
        tmp = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        plt.savefig(tmp.name, dpi=300)
        plt.close()

        print(f"✅ Visualization saved to {tmp.name}")
        return tmp.name

    except Exception as e:
        print(f"❌ Visualization error: {e}")
        import traceback
        traceback.print_exc()
        return image_path
196
+
197
+
198
# ===== Self-test entry point =====
if __name__ == "__main__":
    banner = "=" * 60
    print(banner)
    print("Testing Counting Model")
    print(banner)

    # Load the model first; everything else depends on it.
    model, device = load_model(use_box=False)

    if model is None:
        print("\n❌ Model loading failed")
    else:
        print("\n" + banner)
        print("Model loaded successfully, testing inference...")
        print(banner)

        test_image = "example_imgs/1977_Well_F-5_Field_1.png"

        if not os.path.exists(test_image):
            print(f"\n⚠️ Test image not found: {test_image}")
        else:
            result = run(
                model,
                test_image,
                box=None,
                device=device,
                visualize=True
            )

            if 'error' in result:
                print(f"\n❌ Inference failed: {result['error']}")
            else:
                print("\n" + banner)
                print("Inference Results:")
                print(banner)
                print(f"Count: {result['count']:.1f}")
                print(f"Density map shape: {result['density_map'].shape}")
                if result['visualized_path']:
                    print(f"Visualization saved to: {result['visualized_path']}")
inference_seg.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from huggingface_hub import hf_hub_download
4
+ from segmentation import SegmentationModule
5
+
6
+ MODEL = None
7
+ DEVICE = torch.device("cpu")
8
+
9
def load_model(use_box=False):
    """Build the segmentation module, pull its checkpoint from the Hub,
    and place it on the best available device.

    Args:
        use_box: whether the model should consume exemplar bounding boxes.

    Returns:
        (model, device)
    """
    global MODEL, DEVICE
    MODEL = SegmentationModule(use_box=use_box)

    ckpt_path = hf_hub_download(
        repo_id="phoebe777777/111",
        filename="microscopy_matching_seg.pth",
        token=None,
        force_download=False,
    )
    # strict=False: tolerate keys missing from / extra in the checkpoint.
    state = torch.load(ckpt_path, map_location="cpu")
    MODEL.load_state_dict(state, strict=False)
    MODEL.eval()

    on_cuda = torch.cuda.is_available()
    DEVICE = torch.device("cuda" if on_cuda else "cpu")
    MODEL.move_to_device(DEVICE)
    print("✅ Model moved to CUDA" if on_cuda else "✅ Model on CPU")
    return MODEL, DEVICE
30
+
31
+
32
@torch.no_grad()
def run(model, img_path, box=None, device="cpu"):
    """Run segmentation on one image and return the predicted mask.

    `model.use_box` is toggled according to whether exemplar boxes were
    supplied before invoking the model.
    """
    print("DEVICE:", device)
    model.move_to_device(device)
    model.eval()
    with torch.no_grad():
        model.use_box = box is not None
        return model(img_path, box=box)
46
+ # import os
47
+ # import torch
48
+ # import numpy as np
49
+ # from huggingface_hub import hf_hub_download
50
+ # from segmentation import SegmentationModule
51
+
52
+ # MODEL = None
53
+ # DEVICE = torch.device("cpu")
54
+
55
+ # def load_model(use_box=False):
56
+ # global MODEL, DEVICE
57
+
58
+ # # === 优化1: 使用 /data 缓存模型,避免写入 .cache ===
59
+ # cache_dir = "/data/cellseg_model_cache"
60
+ # os.makedirs(cache_dir, exist_ok=True)
61
+
62
+ # ckpt_path = hf_hub_download(
63
+ # repo_id="Shengxiao0709/cellsegmodel",
64
+ # filename="microscopy_matching_seg.pth",
65
+ # token=None,
66
+ # local_dir=cache_dir, # ✅ 下载到 /data
67
+ # local_dir_use_symlinks=False, # ✅ 避免软链接问题
68
+ # force_download=False # ✅ 已存在时不重复下载
69
+ # )
70
+
71
+ # # === 优化2: 加载模型 ===
72
+ # MODEL = SegmentationModule(use_box=use_box)
73
+ # state_dict = torch.load(ckpt_path, map_location="cpu")
74
+ # MODEL.load_state_dict(state_dict, strict=False)
75
+ # MODEL.eval()
76
+
77
+ # DEVICE = torch.device("cpu")
78
+ # print(f"✅ Model loaded from {ckpt_path}")
79
+ # return MODEL, DEVICE
80
+
81
+
82
+ # @torch.no_grad()
83
+ # def run(model, img_path, box=None, device="cpu"):
84
+ # output = model(img_path, box=box)
85
+ # mask = output["pred"]
86
+ # mask = (mask > 0).astype(np.uint8)
87
+ # return mask
inference_track.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # inference_track.py
2
+ # 视频跟踪模型推理模块
3
+
4
+ import torch
5
+ import numpy as np
6
+ import os
7
+ from pathlib import Path
8
+ from tqdm import tqdm
9
+ from huggingface_hub import hf_hub_download
10
+ from tracking_one import TrackingModule
11
+ from models.tra_post_model.trackastra.tracking import graph_to_ctc
12
+
13
+ MODEL = None
14
+ DEVICE = torch.device("cpu")
15
+
16
def load_model(use_box=False):
    """Load the tracking model and its checkpoint from the Hugging Face Hub.

    Args:
        use_box: whether the model should consume exemplar bounding boxes.

    Returns:
        (model, device) on success, or (None, cpu device) on any failure.
    """
    global MODEL, DEVICE

    try:
        print("🔄 Loading tracking model...")

        MODEL = TrackingModule(use_box=use_box)

        ckpt_path = hf_hub_download(
            repo_id="phoebe777777/111",
            filename="microscopy_matching_tra.pth",
            token=None,
            force_download=False,
        )
        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        state = torch.load(ckpt_path, map_location="cpu")
        MODEL.load_state_dict(state, strict=True)
        MODEL.eval()

        # Prefer CUDA when present; fall back to CPU otherwise.
        on_cuda = torch.cuda.is_available()
        DEVICE = torch.device("cuda" if on_cuda else "cpu")
        MODEL.move_to_device(DEVICE)
        print("✅ Model moved to CUDA" if on_cuda else "✅ Model on CPU")

        print("✅ Tracking model loaded successfully")
        return MODEL, DEVICE

    except Exception as e:
        print(f"❌ Error loading tracking model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")
70
+
71
+
72
@torch.no_grad()
def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
    """Run tracking on a frame-sequence directory and export CTC results.

    Args:
        model: the tracking model, or None if loading failed.
        video_dir: directory containing the consecutive frame images.
        box: optional exemplar boxes.
        device: device to run on (currently unused inside this function).
        output_dir: directory where CTC-format results are written.

    Returns:
        dict with keys 'track_graph', 'masks', 'masks_tracked', and
        'output_dir'; on failure a dict with 'error' and null fields.
    """
    # Guard against a model that failed to load.
    if model is None:
        return {
            'track_graph': None,
            'masks': None,
            'output_dir': None,
            'num_tracks': 0,
            'error': 'Model not loaded'
        }

    try:
        print(f"🔄 Running tracking inference on {video_dir}")

        track_graph, masks = model.track(
            file_dir=video_dir,
            boxes=box,
            mode="greedy",  # alternatives: "greedy_nodiv", "ilp"
            dataname="tracking_result",
        )

        # Ensure the output directory exists (no-op when already present).
        os.makedirs(output_dir, exist_ok=True)

        print("🔄 Converting to CTC format...")
        ctc_tracks, masks_tracked = graph_to_ctc(
            track_graph,
            masks,
            outdir=output_dir,
        )
        print(f"✅ CTC results saved to {output_dir}")

        print(f"✅ Tracking completed")

        return {
            'track_graph': track_graph,
            'masks': masks,
            'masks_tracked': masks_tracked,
            'output_dir': output_dir,
        }

    except Exception as e:
        print(f"❌ Tracking inference error: {e}")
        import traceback
        traceback.print_exc()
        return {
            'track_graph': None,
            'masks': None,
            'output_dir': None,
            'num_tracks': 0,
            'error': str(e)
        }
+ }
150
+
151
+
152
def visualize_tracking_result(masks_tracked, output_path):
    """Render an ID-colored mp4 of tracked instance masks (best effort).

    Args:
        masks_tracked: labeled mask stack of shape (T, H, W); 0 = background.
        output_path: destination video file path.

    Returns:
        output_path on success, None on any failure.
    """
    try:
        import cv2
        import matplotlib.pyplot as plt
        from matplotlib import cm

        n_frames, height, width = masks_tracked.shape

        # One stable color per track id across the whole sequence.
        unique_ids = np.unique(masks_tracked)
        num_colors = len(unique_ids)
        cmap = cm.get_cmap('tab20', num_colors)

        writer = cv2.VideoWriter(
            output_path,
            cv2.VideoWriter_fourcc(*'mp4v'),
            5.0,
            (width, height),
        )

        for t in range(n_frames):
            labels = masks_tracked[t]

            rgb = np.zeros((height, width, 3), dtype=np.uint8)
            for i, obj_id in enumerate(unique_ids):
                if obj_id == 0:
                    continue  # background stays black
                color = np.array(cmap(i % num_colors)[:3]) * 255
                rgb[labels == obj_id] = color

            # OpenCV expects BGR frames.
            writer.write(cv2.cvtColor(rgb, cv2.COLOR_RGB2BGR))

        writer.release()
        print(f"✅ Visualization saved to {output_path}")
        return output_path

    except Exception as e:
        print(f"❌ Visualization error: {e}")
        return None
202
+ return None
models/.DS_Store ADDED
Binary file (6.15 kB). View file
 
models/enc_model/__init__.py ADDED
File without changes
models/enc_model/backbone.py ADDED
@@ -0,0 +1,64 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+ from torch.nn import functional as F
4
+ from torchvision import models
5
+ from torchvision.ops.misc import FrozenBatchNorm2d
6
+
7
+
8
+ class Backbone(nn.Module):
9
+
10
+ def __init__(
11
+ self,
12
+ name: str,
13
+ pretrained: bool,
14
+ dilation: bool,
15
+ reduction: int,
16
+ swav: bool,
17
+ requires_grad: bool
18
+ ):
19
+
20
+ super(Backbone, self).__init__()
21
+
22
+ resnet = getattr(models, name)(
23
+ replace_stride_with_dilation=[False, False, dilation],
24
+ pretrained=pretrained, norm_layer=FrozenBatchNorm2d
25
+ )
26
+
27
+ self.backbone = resnet
28
+ self.reduction = reduction
29
+
30
+ if name == 'resnet50' and swav:
31
+ checkpoint = torch.hub.load_state_dict_from_url(
32
+ 'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar',
33
+ map_location="cpu"
34
+ )
35
+ state_dict = {k.replace("module.", ""): v for k, v in checkpoint.items()}
36
+ self.backbone.load_state_dict(state_dict, strict=False)
37
+
38
+ # concatenation of layers 2, 3 and 4
39
+ self.num_channels = 896 if name in ['resnet18', 'resnet34'] else 3584
40
+
41
+ for n, param in self.backbone.named_parameters():
42
+ if 'layer2' not in n and 'layer3' not in n and 'layer4' not in n:
43
+ param.requires_grad_(False)
44
+ else:
45
+ param.requires_grad_(requires_grad)
46
+
47
+ def forward(self, x):
48
+ size = x.size(-2) // self.reduction, x.size(-1) // self.reduction
49
+ x = self.backbone.conv1(x)
50
+ x = self.backbone.bn1(x)
51
+ x = self.backbone.relu(x)
52
+ x = self.backbone.maxpool(x)
53
+
54
+ x = self.backbone.layer1(x)
55
+ x = layer2 = self.backbone.layer2(x)
56
+ x = layer3 = self.backbone.layer3(x)
57
+ x = layer4 = self.backbone.layer4(x)
58
+
59
+ x = torch.cat([
60
+ F.interpolate(f, size=size, mode='bilinear', align_corners=True)
61
+ for f in [layer2, layer3, layer4]
62
+ ], dim=1)
63
+
64
+ return x
models/enc_model/loca.py ADDED
@@ -0,0 +1,232 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .backbone import Backbone
2
+ from .transformer import TransformerEncoder
3
+ from .ope import OPEModule
4
+ from .positional_encoding import PositionalEncodingsFixed
5
+ from .regression_head import DensityMapRegressor
6
+
7
+ import torch
8
+ from torch import nn
9
+ from torch.nn import functional as F
10
+
11
+
12
+ class LOCA(nn.Module):
13
+
14
+ def __init__(
15
+ self,
16
+ image_size: int,
17
+ num_encoder_layers: int,
18
+ num_ope_iterative_steps: int,
19
+ num_objects: int,
20
+ emb_dim: int,
21
+ num_heads: int,
22
+ kernel_dim: int,
23
+ backbone_name: str,
24
+ swav_backbone: bool,
25
+ train_backbone: bool,
26
+ reduction: int,
27
+ dropout: float,
28
+ layer_norm_eps: float,
29
+ mlp_factor: int,
30
+ norm_first: bool,
31
+ activation: nn.Module,
32
+ norm: bool,
33
+ zero_shot: bool,
34
+ ):
35
+
36
+ super(LOCA, self).__init__()
37
+
38
+ self.emb_dim = emb_dim
39
+ self.num_objects = num_objects
40
+ self.reduction = reduction
41
+ self.kernel_dim = kernel_dim
42
+ self.image_size = image_size
43
+ self.zero_shot = zero_shot
44
+ self.num_heads = num_heads
45
+ self.num_encoder_layers = num_encoder_layers
46
+
47
+ self.backbone = Backbone(
48
+ backbone_name, pretrained=True, dilation=False, reduction=reduction,
49
+ swav=swav_backbone, requires_grad=train_backbone
50
+ )
51
+ self.input_proj = nn.Conv2d(
52
+ self.backbone.num_channels, emb_dim, kernel_size=1
53
+ )
54
+
55
+ if num_encoder_layers > 0:
56
+ self.encoder = TransformerEncoder(
57
+ num_encoder_layers, emb_dim, num_heads, dropout, layer_norm_eps,
58
+ mlp_factor, norm_first, activation, norm
59
+ )
60
+
61
+ self.ope = OPEModule(
62
+ num_ope_iterative_steps, emb_dim, kernel_dim, num_objects, num_heads,
63
+ reduction, layer_norm_eps, mlp_factor, norm_first, activation, norm, zero_shot
64
+ )
65
+
66
+ self.regression_head = DensityMapRegressor(emb_dim, reduction)
67
+ self.aux_heads = nn.ModuleList([
68
+ DensityMapRegressor(emb_dim, reduction)
69
+ for _ in range(num_ope_iterative_steps - 1)
70
+ ])
71
+
72
+ self.pos_emb = PositionalEncodingsFixed(emb_dim)
73
+
74
+ self.attn_norm = nn.LayerNorm(normalized_shape=(64, 64))
75
+ self.fuse = nn.Sequential(
76
+ nn.Conv2d(324, 256, kernel_size=1, stride=1),
77
+ nn.LeakyReLU(),
78
+ nn.LayerNorm((64, 64))
79
+ )
80
+
81
+ # self.fuse1 = nn.Sequential(
82
+ # nn.Conv2d(322, 256, kernel_size=1, stride=1),
83
+ # nn.LeakyReLU(),
84
+ # nn.LayerNorm((64, 64))
85
+ # )
86
+
87
+ def forward_before_reg(self, x, bboxes):
88
+ num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects
89
+ # backbone
90
+ backbone_features = self.backbone(x)
91
+ # prepare the encoder input
92
+ src = self.input_proj(backbone_features)
93
+ bs, c, h, w = src.size()
94
+ pos_emb = self.pos_emb(bs, h, w, src.device).flatten(2).permute(2, 0, 1)
95
+ src = src.flatten(2).permute(2, 0, 1)
96
+
97
+ # push through the encoder
98
+ if self.num_encoder_layers > 0:
99
+ image_features = self.encoder(src, pos_emb, src_key_padding_mask=None, src_mask=None)
100
+ else:
101
+ image_features = src
102
+
103
+ # prepare OPE input
104
+ f_e = image_features.permute(1, 2, 0).reshape(-1, self.emb_dim, h, w)
105
+
106
+ all_prototypes = self.ope(f_e, pos_emb, bboxes) # [3, 27, 1, 256]
107
+
108
+ outputs = list()
109
+ response_maps_list = []
110
+ for i in range(all_prototypes.size(0)):
111
+ prototypes = all_prototypes[i, ...].permute(1, 0, 2).reshape(
112
+ bs, num_objects, self.kernel_dim, self.kernel_dim, -1
113
+ ).permute(0, 1, 4, 2, 3).flatten(0, 2)[:, None, ...] # [768, 1, 3, 3]
114
+
115
+ response_maps = F.conv2d(
116
+ torch.cat([f_e for _ in range(num_objects)], dim=1).flatten(0, 1).unsqueeze(0),
117
+ prototypes,
118
+ bias=None,
119
+ padding=self.kernel_dim // 2,
120
+ groups=prototypes.size(0)
121
+ ).view(
122
+ bs, num_objects, self.emb_dim, h, w
123
+ ).max(dim=1)[0]
124
+
125
+ # # send through regression heads
126
+ # if i == all_prototypes.size(0) - 1:
127
+ # predicted_dmaps = self.regression_head(response_maps)
128
+ # else:
129
+ # predicted_dmaps = self.aux_heads[i](response_maps)
130
+ # outputs.append(predicted_dmaps)
131
+ response_maps_list.append(response_maps)
132
+
133
+ out = {
134
+ # "pred": outputs[-1],
135
+ "feature_bf_regression": response_maps_list[-1],
136
+ # "aux_pred": outputs[:-1],
137
+ "aux_feature_bf_regression": response_maps_list[:-1]
138
+ }
139
+
140
+ return out
141
+
142
+ def forward_reg(self, response_maps, attn_stack, unet_feature):
143
+ attn_stack = self.attn_norm(attn_stack)
144
+ attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
145
+ unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
146
+ unet_feature = unet_feature * attn_stack_mean
147
+ if unet_feature.shape[1] == 322:
148
+ unet_feature = self.fuse1(unet_feature)
149
+ else:
150
+ unet_feature = self.fuse(unet_feature)
151
+
152
+ response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
153
+
154
+ outputs = []
155
+ for i in range(len(response_maps)):
156
+ response_map = response_maps[i] + unet_feature
157
+ if i == len(response_maps) - 1:
158
+ predicted_dmaps = self.regression_head(response_map)
159
+ else:
160
+ predicted_dmaps = self.aux_heads[i](response_map)
161
+ outputs.append(predicted_dmaps)
162
+
163
+ return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
164
+
165
+ def forward_reg1(self, response_maps, self_attn):
166
+ # attn_stack = self.attn_norm(attn_stack)
167
+ # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
168
+ # unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
169
+ # unet_feature = unet_feature * attn_stack_mean
170
+ # if unet_feature.shape[1] == 322:
171
+ # unet_feature = self.fuse1(unet_feature)
172
+ # else:
173
+ # unet_feature = self.fuse(unet_feature)
174
+
175
+
176
+
177
+ response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
178
+
179
+ outputs = []
180
+ for i in range(len(response_maps)):
181
+ response_map = response_maps[i] + self_attn
182
+ if i == len(response_maps) - 1:
183
+ predicted_dmaps = self.regression_head(response_map)
184
+ else:
185
+ predicted_dmaps = self.aux_heads[i](response_map)
186
+ outputs.append(predicted_dmaps)
187
+
188
+ return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
189
+
190
+ def forward_reg_without_unet(self, response_maps, attn_stack):
191
+ # attn_stack = self.attn_norm(attn_stack)
192
+ attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
193
+
194
+ response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
195
+
196
+ outputs = []
197
+ for i in range(len(response_maps)):
198
+ response_map = response_maps[i] * attn_stack_mean * 0.5 + response_maps[i]
199
+ if i == len(response_maps) - 1:
200
+ predicted_dmaps = self.regression_head(response_map)
201
+ else:
202
+ predicted_dmaps = self.aux_heads[i](response_map)
203
+ outputs.append(predicted_dmaps)
204
+
205
+ return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
206
+
207
+
208
+ def build_model(args):
209
+
210
+ assert args.backbone in ['resnet18', 'resnet50', 'resnet101']
211
+ assert args.reduction in [4, 8, 16]
212
+
213
+ return LOCA(
214
+ image_size=args.image_size,
215
+ num_encoder_layers=args.num_enc_layers,
216
+ num_ope_iterative_steps=args.num_ope_iterative_steps,
217
+ num_objects=args.num_objects,
218
+ zero_shot=args.zero_shot,
219
+ emb_dim=args.emb_dim,
220
+ num_heads=args.num_heads,
221
+ kernel_dim=args.kernel_dim,
222
+ backbone_name=args.backbone,
223
+ swav_backbone=args.swav_backbone,
224
+ train_backbone=args.backbone_lr > 0,
225
+ reduction=args.reduction,
226
+ dropout=args.dropout,
227
+ layer_norm_eps=1e-5,
228
+ mlp_factor=8,
229
+ norm_first=args.pre_norm,
230
+ activation=nn.GELU,
231
+ norm=True,
232
+ )
models/enc_model/loca_args.py ADDED
@@ -0,0 +1,44 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+
3
+
4
+ def get_argparser():
5
+
6
+ parser = argparse.ArgumentParser("LOCA parser", add_help=False)
7
+
8
+ parser.add_argument('--model_name', default='loca_few_shot', type=str)
9
+ parser.add_argument(
10
+ '--data_path',
11
+ default='./data/FSC147_384_V2',
12
+ type=str
13
+ )
14
+ parser.add_argument(
15
+ '--model_path',
16
+ default='ckpt',
17
+ type=str
18
+ )
19
+ parser.add_argument('--backbone', default='resnet50', type=str)
20
+ parser.add_argument('--swav_backbone', action='store_true', default=True)
21
+ parser.add_argument('--reduction', default=8, type=int)
22
+ parser.add_argument('--image_size', default=512, type=int)
23
+ parser.add_argument('--num_enc_layers', default=3, type=int)
24
+ parser.add_argument('--num_ope_iterative_steps', default=3, type=int)
25
+ parser.add_argument('--emb_dim', default=256, type=int)
26
+ parser.add_argument('--num_heads', default=8, type=int)
27
+ parser.add_argument('--kernel_dim', default=3, type=int)
28
+ parser.add_argument('--num_objects', default=3, type=int)
29
+ parser.add_argument('--epochs', default=200, type=int)
30
+ parser.add_argument('--resume_training', action='store_true')
31
+ parser.add_argument('--lr', default=1e-4, type=float)
32
+ parser.add_argument('--backbone_lr', default=0, type=float)
33
+ parser.add_argument('--lr_drop', default=200, type=int)
34
+ parser.add_argument('--weight_decay', default=1e-4, type=float)
35
+ parser.add_argument('--batch_size', default=1, type=int)
36
+ parser.add_argument('--dropout', default=0.1, type=float)
37
+ parser.add_argument('--num_workers', default=8, type=int)
38
+ parser.add_argument('--max_grad_norm', default=0.1, type=float)
39
+ parser.add_argument('--aux_weight', default=0.3, type=float)
40
+ parser.add_argument('--tiling_p', default=0.5, type=float)
41
+ parser.add_argument('--zero_shot', action='store_true')
42
+ parser.add_argument('--pre_norm', action='store_true', default=True)
43
+
44
+ return parser
models/enc_model/mlp.py ADDED
@@ -0,0 +1,23 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+
3
+
4
+ class MLP(nn.Module):
5
+
6
+ def __init__(
7
+ self,
8
+ input_dim: int,
9
+ hidden_dim: int,
10
+ dropout: float,
11
+ activation: nn.Module
12
+ ):
13
+ super(MLP, self).__init__()
14
+
15
+ self.linear1 = nn.Linear(input_dim, hidden_dim)
16
+ self.linear2 = nn.Linear(hidden_dim, input_dim)
17
+ self.dropout = nn.Dropout(dropout)
18
+ self.activation = activation()
19
+
20
+ def forward(self, x):
21
+ return (
22
+ self.linear2(self.dropout(self.activation(self.linear1(x))))
23
+ )
models/enc_model/ope.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mlp import MLP
2
+ from .positional_encoding import PositionalEncodingsFixed
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from torchvision.ops import roi_align
8
+
9
+
10
+ class OPEModule(nn.Module):
11
+
12
+ def __init__(
13
+ self,
14
+ num_iterative_steps: int,
15
+ emb_dim: int,
16
+ kernel_dim: int,
17
+ num_objects: int,
18
+ num_heads: int,
19
+ reduction: int,
20
+ layer_norm_eps: float,
21
+ mlp_factor: int,
22
+ norm_first: bool,
23
+ activation: nn.Module,
24
+ norm: bool,
25
+ zero_shot: bool,
26
+ ):
27
+
28
+ super(OPEModule, self).__init__()
29
+
30
+ self.num_iterative_steps = num_iterative_steps
31
+ self.zero_shot = zero_shot
32
+ self.kernel_dim = kernel_dim
33
+ self.num_objects = num_objects
34
+ self.emb_dim = emb_dim
35
+ self.reduction = reduction
36
+
37
+ if num_iterative_steps > 0:
38
+ self.iterative_adaptation = IterativeAdaptationModule(
39
+ num_layers=num_iterative_steps, emb_dim=emb_dim, num_heads=num_heads,
40
+ dropout=0, layer_norm_eps=layer_norm_eps,
41
+ mlp_factor=mlp_factor, norm_first=norm_first,
42
+ activation=activation, norm=norm,
43
+ zero_shot=zero_shot
44
+ )
45
+
46
+ if not self.zero_shot:
47
+ self.shape_or_objectness = nn.Sequential(
48
+ nn.Linear(2, 64),
49
+ nn.ReLU(),
50
+ nn.Linear(64, emb_dim),
51
+ nn.ReLU(),
52
+ nn.Linear(emb_dim, self.kernel_dim**2 * emb_dim)
53
+ )
54
+ else:
55
+ self.shape_or_objectness = nn.Parameter(
56
+ torch.empty((self.num_objects, self.kernel_dim**2, emb_dim))
57
+ )
58
+ nn.init.normal_(self.shape_or_objectness)
59
+
60
+ self.pos_emb = PositionalEncodingsFixed(emb_dim)
61
+
62
+ def forward(self, f_e, pos_emb, bboxes):
63
+ bs, _, h, w = f_e.size()
64
+ # extract the shape features or objectness
65
+ if not self.zero_shot:
66
+ box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
67
+ box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
68
+ box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]
69
+ shape_or_objectness = self.shape_or_objectness(box_hw).reshape(
70
+ bs, -1, self.kernel_dim ** 2, self.emb_dim
71
+ ).flatten(1, 2).transpose(0, 1)
72
+ else:
73
+ shape_or_objectness = self.shape_or_objectness.expand(
74
+ bs, -1, -1, -1
75
+ ).flatten(1, 2).transpose(0, 1)
76
+
77
+ # if not zero shot add appearance
78
+ if not self.zero_shot:
79
+ # reshape bboxes into the format suitable for roi_align
80
+ num_of_boxes = bboxes.size(1)
81
+ bboxes = torch.cat([
82
+ torch.arange(
83
+ bs, requires_grad=False
84
+ ).to(bboxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
85
+ bboxes.flatten(0, 1),
86
+ ], dim=1)
87
+ appearance = roi_align(
88
+ f_e,
89
+ boxes=bboxes, output_size=self.kernel_dim,
90
+ spatial_scale=1.0 / self.reduction, aligned=True
91
+ ).permute(0, 2, 3, 1).reshape(
92
+ bs, num_of_boxes * self.kernel_dim ** 2, -1
93
+ ).transpose(0, 1)
94
+ else:
95
+ num_of_boxes = self.num_objects
96
+ appearance = None
97
+
98
+ query_pos_emb = self.pos_emb(
99
+ bs, self.kernel_dim, self.kernel_dim, f_e.device
100
+ ).flatten(2).permute(2, 0, 1).repeat(num_of_boxes, 1, 1)
101
+
102
+ if self.num_iterative_steps > 0:
103
+ memory = f_e.flatten(2).permute(2, 0, 1)
104
+ all_prototypes = self.iterative_adaptation(
105
+ shape_or_objectness, appearance, memory, pos_emb, query_pos_emb
106
+ )
107
+ else:
108
+ if shape_or_objectness is not None and appearance is not None:
109
+ all_prototypes = (shape_or_objectness + appearance).unsqueeze(0)
110
+ else:
111
+ all_prototypes = (
112
+ shape_or_objectness if shape_or_objectness is not None else appearance
113
+ ).unsqueeze(0)
114
+
115
+ return all_prototypes
116
+
117
+
118
+
119
+ class IterativeAdaptationModule(nn.Module):
120
+
121
+ def __init__(
122
+ self,
123
+ num_layers: int,
124
+ emb_dim: int,
125
+ num_heads: int,
126
+ dropout: float,
127
+ layer_norm_eps: float,
128
+ mlp_factor: int,
129
+ norm_first: bool,
130
+ activation: nn.Module,
131
+ norm: bool,
132
+ zero_shot: bool
133
+ ):
134
+
135
+ super(IterativeAdaptationModule, self).__init__()
136
+
137
+ self.layers = nn.ModuleList([
138
+ IterativeAdaptationLayer(
139
+ emb_dim, num_heads, dropout, layer_norm_eps,
140
+ mlp_factor, norm_first, activation, zero_shot
141
+ ) for i in range(num_layers)
142
+ ])
143
+
144
+ self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()
145
+
146
+ def forward(
147
+ self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask=None, memory_mask=None,
148
+ tgt_key_padding_mask=None, memory_key_padding_mask=None
149
+ ):
150
+
151
+ output = tgt
152
+ outputs = list()
153
+ for i, layer in enumerate(self.layers):
154
+ output = layer(
155
+ output, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
156
+ tgt_key_padding_mask, memory_key_padding_mask
157
+ )
158
+ outputs.append(self.norm(output))
159
+
160
+ return torch.stack(outputs)
161
+
162
+
163
+ class IterativeAdaptationLayer(nn.Module):
164
+
165
+ def __init__(
166
+ self,
167
+ emb_dim: int,
168
+ num_heads: int,
169
+ dropout: float,
170
+ layer_norm_eps: float,
171
+ mlp_factor: int,
172
+ norm_first: bool,
173
+ activation: nn.Module,
174
+ zero_shot: bool
175
+ ):
176
+ super(IterativeAdaptationLayer, self).__init__()
177
+
178
+ self.norm_first = norm_first
179
+ self.zero_shot = zero_shot
180
+
181
+ if not self.zero_shot:
182
+ self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
183
+ self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
184
+ self.norm3 = nn.LayerNorm(emb_dim, layer_norm_eps)
185
+ if not self.zero_shot:
186
+ self.dropout1 = nn.Dropout(dropout)
187
+ self.dropout2 = nn.Dropout(dropout)
188
+ self.dropout3 = nn.Dropout(dropout)
189
+
190
+ if not self.zero_shot:
191
+ self.self_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
192
+ self.enc_dec_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
193
+
194
+ self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)
195
+
196
+ def with_emb(self, x, emb):
197
+ return x if emb is None else x + emb
198
+
199
+ def forward(
200
+ self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
201
+ tgt_key_padding_mask, memory_key_padding_mask
202
+ ):
203
+ if self.norm_first:
204
+ if not self.zero_shot:
205
+ tgt_norm = self.norm1(tgt)
206
+ tgt = tgt + self.dropout1(self.self_attn(
207
+ query=self.with_emb(tgt_norm, query_pos_emb),
208
+ key=self.with_emb(appearance, query_pos_emb),
209
+ value=appearance,
210
+ attn_mask=tgt_mask,
211
+ key_padding_mask=tgt_key_padding_mask
212
+ )[0])
213
+
214
+ tgt_norm = self.norm2(tgt)
215
+ tgt = tgt + self.dropout2(self.enc_dec_attn(
216
+ query=self.with_emb(tgt_norm, query_pos_emb),
217
+ key=memory+pos_emb,
218
+ value=memory,
219
+ attn_mask=memory_mask,
220
+ key_padding_mask=memory_key_padding_mask
221
+ )[0])
222
+ tgt_norm = self.norm3(tgt)
223
+ tgt = tgt + self.dropout3(self.mlp(tgt_norm))
224
+
225
+ else:
226
+ if not self.zero_shot:
227
+ tgt = self.norm1(tgt + self.dropout1(self.self_attn(
228
+ query=self.with_emb(tgt, query_pos_emb),
229
+ key=self.with_emb(appearance),
230
+ value=appearance,
231
+ attn_mask=tgt_mask,
232
+ key_padding_mask=tgt_key_padding_mask
233
+ )[0]))
234
+
235
+ tgt = self.norm2(tgt + self.dropout2(self.enc_dec_attn(
236
+ query=self.with_emb(tgt, query_pos_emb),
237
+ key=memory+pos_emb,
238
+ value=memory,
239
+ attn_mask=memory_mask,
240
+ key_padding_mask=memory_key_padding_mask
241
+ )[0]))
242
+
243
+ tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))
244
+
245
+ return tgt
models/enc_model/positional_encoding.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+
5
+ class PositionalEncodingsFixed(nn.Module):
6
+
7
+ def __init__(self, emb_dim, temperature=10000):
8
+
9
+ super(PositionalEncodingsFixed, self).__init__()
10
+
11
+ self.emb_dim = emb_dim
12
+ self.temperature = temperature
13
+
14
+ def _1d_pos_enc(self, mask, dim):
15
+ temp = torch.arange(self.emb_dim // 2).float().to(mask.device)
16
+ temp = self.temperature ** (2 * (temp.div(2, rounding_mode='floor')) / self.emb_dim)
17
+
18
+ enc = (~mask).cumsum(dim).float().unsqueeze(-1) / temp
19
+ enc = torch.stack([
20
+ enc[..., 0::2].sin(), enc[..., 1::2].cos()
21
+ ], dim=-1).flatten(-2)
22
+
23
+ return enc
24
+
25
+ def forward(self, bs, h, w, device):
26
+ mask = torch.zeros(bs, h, w, dtype=torch.bool, requires_grad=False, device=device)
27
+ x = self._1d_pos_enc(mask, dim=2)
28
+ y = self._1d_pos_enc(mask, dim=1)
29
+
30
+ return torch.cat([y, x], dim=3).permute(0, 3, 1, 2)
models/enc_model/regression_head.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from torch import nn
2
+ import torch
3
+
4
+
5
+ class UpsamplingLayer(nn.Module):
6
+
7
+ def __init__(self, in_channels, out_channels, leaky=True):
8
+
9
+ super(UpsamplingLayer, self).__init__()
10
+
11
+ self.layer = nn.Sequential(
12
+ nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
13
+ nn.LeakyReLU() if leaky else nn.ReLU(),
14
+ nn.UpsamplingBilinear2d(scale_factor=2)
15
+ )
16
+
17
+ def forward(self, x):
18
+ return self.layer(x)
19
+
20
+
21
+ class DensityMapRegressor(nn.Module):
22
+
23
+ def __init__(self, in_channels, reduction):
24
+
25
+ super(DensityMapRegressor, self).__init__()
26
+
27
+ if reduction == 8:
28
+ self.regressor = nn.Sequential(
29
+ UpsamplingLayer(in_channels, 128),
30
+ UpsamplingLayer(128, 64),
31
+ UpsamplingLayer(64, 32),
32
+ nn.Conv2d(32, 1, kernel_size=1),
33
+ nn.LeakyReLU()
34
+ )
35
+ elif reduction == 16:
36
+ self.regressor = nn.Sequential(
37
+ UpsamplingLayer(in_channels, 128),
38
+ UpsamplingLayer(128, 64),
39
+ UpsamplingLayer(64, 32),
40
+ UpsamplingLayer(32, 16),
41
+ nn.Conv2d(16, 1, kernel_size=1),
42
+ nn.LeakyReLU()
43
+ )
44
+
45
+ self.reset_parameters()
46
+
47
+ def forward(self, x):
48
+ return self.regressor(x)
49
+
50
+ def reset_parameters(self):
51
+ for module in self.modules():
52
+ if isinstance(module, nn.Conv2d):
53
+ nn.init.normal_(module.weight, std=0.01)
54
+ if module.bias is not None:
55
+ nn.init.constant_(module.bias, 0)
56
+
57
+
58
+ class DensityMapRegressor_(nn.Module):
59
+
60
+ def __init__(self, in_channels, reduction):
61
+
62
+ super(DensityMapRegressor, self).__init__()
63
+
64
+ if reduction == 8:
65
+ self.regressor = nn.Sequential(
66
+ UpsamplingLayer(in_channels, 128),
67
+ UpsamplingLayer(128, 64),
68
+ UpsamplingLayer(64, 32),
69
+ nn.Conv2d(32, 1, kernel_size=1),
70
+ nn.LeakyReLU()
71
+ )
72
+ elif reduction == 16:
73
+ self.regressor = nn.Sequential(
74
+ UpsamplingLayer(in_channels, 128),
75
+ UpsamplingLayer(128, 64),
76
+ UpsamplingLayer(64, 32),
77
+ UpsamplingLayer(32, 16),
78
+ nn.Conv2d(16, 1, kernel_size=1),
79
+ nn.LeakyReLU()
80
+ )
81
+
82
+ self.reset_parameters()
83
+
84
+ def forward(self, x):
85
+ return self.regressor(x)
86
+
87
+ def reset_parameters(self):
88
+ for module in self.modules():
89
+ if isinstance(module, nn.Conv2d):
90
+ nn.init.normal_(module.weight, std=0.01)
91
+ if module.bias is not None:
92
+ nn.init.constant_(module.bias, 0)
models/enc_model/transformer.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .mlp import MLP
2
+
3
+ from torch import nn
4
+
5
+
6
+ class TransformerEncoder(nn.Module):
7
+
8
+ def __init__(
9
+ self,
10
+ num_layers: int,
11
+ emb_dim: int,
12
+ num_heads: int,
13
+ dropout: float,
14
+ layer_norm_eps: float,
15
+ mlp_factor: int,
16
+ norm_first: bool,
17
+ activation: nn.Module,
18
+ norm: bool,
19
+ ):
20
+
21
+ super(TransformerEncoder, self).__init__()
22
+
23
+ self.layers = nn.ModuleList([
24
+ TransformerEncoderLayer(
25
+ emb_dim, num_heads, dropout, layer_norm_eps,
26
+ mlp_factor, norm_first, activation
27
+ ) for _ in range(num_layers)
28
+ ])
29
+
30
+ self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()
31
+
32
+ def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
33
+ output = src
34
+ for layer in self.layers:
35
+ output = layer(output, pos_emb, src_mask, src_key_padding_mask)
36
+ return self.norm(output)
37
+
38
+
39
+ class TransformerEncoderLayer(nn.Module):
40
+
41
+ def __init__(
42
+ self,
43
+ emb_dim: int,
44
+ num_heads: int,
45
+ dropout: float,
46
+ layer_norm_eps: float,
47
+ mlp_factor: int,
48
+ norm_first: bool,
49
+ activation: nn.Module,
50
+ ):
51
+ super(TransformerEncoderLayer, self).__init__()
52
+
53
+ self.norm_first = norm_first
54
+
55
+ self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
56
+ self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
57
+ self.dropout1 = nn.Dropout(dropout)
58
+ self.dropout2 = nn.Dropout(dropout)
59
+
60
+ self.self_attn = nn.MultiheadAttention(
61
+ emb_dim, num_heads, dropout
62
+ )
63
+ self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)
64
+
65
+ def with_emb(self, x, emb):
66
+ return x if emb is None else x + emb
67
+
68
+ def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
69
+ if self.norm_first:
70
+ src_norm = self.norm1(src)
71
+ q = k = src_norm + pos_emb
72
+ src = src + self.dropout1(self.self_attn(
73
+ query=q,
74
+ key=k,
75
+ value=src_norm,
76
+ attn_mask=src_mask,
77
+ key_padding_mask=src_key_padding_mask
78
+ )[0])
79
+
80
+ src_norm = self.norm2(src)
81
+ src = src + self.dropout2(self.mlp(src_norm))
82
+ else:
83
+ q = k = src + pos_emb
84
+ src = self.norm1(src + self.dropout1(self.self_attn(
85
+ query=q,
86
+ key=k,
87
+ value=src,
88
+ attn_mask=src_mask,
89
+ key_padding_mask=src_key_padding_mask
90
+ )[0]))
91
+
92
+ src = self.norm2(src + self.dropout2(self.mlp(src)))
93
+
94
+ return src
models/enc_model/unet_parts.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ Parts of the U-Net model """
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+
8
+ class DoubleConv(nn.Module):
9
+ """(convolution => [BN] => ReLU) * 2"""
10
+
11
+ def __init__(self, in_channels, out_channels, mid_channels=None):
12
+ super().__init__()
13
+ if not mid_channels:
14
+ mid_channels = out_channels
15
+ self.double_conv = nn.Sequential(
16
+ nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
17
+ nn.BatchNorm2d(mid_channels),
18
+ nn.ReLU(inplace=True),
19
+ nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
20
+ nn.BatchNorm2d(out_channels),
21
+ nn.ReLU(inplace=True)
22
+ )
23
+
24
+ def forward(self, x):
25
+ return self.double_conv(x)
26
+
27
+
28
+ class Down(nn.Module):
29
+ """Downscaling with maxpool then double conv"""
30
+
31
+ def __init__(self, in_channels, out_channels):
32
+ super().__init__()
33
+ self.maxpool_conv = nn.Sequential(
34
+ nn.MaxPool2d(2),
35
+ DoubleConv(in_channels, out_channels)
36
+ )
37
+
38
+ def forward(self, x):
39
+ return self.maxpool_conv(x)
40
+
41
+
42
+ class Up(nn.Module):
43
+ """Upscaling then double conv"""
44
+
45
+ def __init__(self, in_channels, out_channels, bilinear=True):
46
+ super().__init__()
47
+
48
+ # if bilinear, use the normal convolutions to reduce the number of channels
49
+ if bilinear:
50
+ self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
51
+ self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
52
+ else:
53
+ self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
54
+ self.conv = DoubleConv(in_channels, out_channels)
55
+
56
+ def forward(self, x1, x2):
57
+ x1 = self.up(x1)
58
+ # input is CHW
59
+ diffY = x2.size()[2] - x1.size()[2]
60
+ diffX = x2.size()[3] - x1.size()[3]
61
+
62
+ x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
63
+ diffY // 2, diffY - diffY // 2])
64
+ # if you have padding issues, see
65
+ # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
66
+ # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
67
+ x = torch.cat([x2, x1], dim=1)
68
+ return self.conv(x)
69
+
70
+
71
+ class OutConv(nn.Module):
72
+ def __init__(self, in_channels, out_channels):
73
+ super(OutConv, self).__init__()
74
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)
75
+
76
+ def forward(self, x):
77
+ return self.conv(x)
models/model.py ADDED
@@ -0,0 +1,653 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import os
5
+ import clip
6
+ import sys
7
+ import numpy as np
8
+ from models.seg_post_model.cellpose.models import CellposeModel
9
+
10
+ from torchvision.ops import roi_align
11
def crop_roi_feat(feat, boxes):
    """Crop box-aligned patches out of a feature map.

    feat: 1 x c x h x w feature map.
    boxes: m x 4 tensor of [y_tl, x_tl, y_br, x_br] coordinates given in
        the (assumed 512-px -- TODO confirm) input-image frame.

    Returns a list of m tensors, each 1 x c x h_box x w_box.
    """
    _, _, h, w = feat.shape
    # Feature stride relative to the assumed 512-px input resolution.
    stride = 512 / h
    scaled = boxes / stride  # fresh tensor; caller's boxes stay untouched
    # Snap outward so the integer crop fully covers the fractional box.
    scaled[:, :2] = torch.clamp_min(torch.floor(scaled[:, :2]), 0)
    scaled[:, 2] = torch.clamp_max(torch.ceil(scaled[:, 2]), h)
    scaled[:, 3] = torch.clamp_max(torch.ceil(scaled[:, 3]), w)
    crops = []
    for box in scaled:
        y0, x0, y1, x1 = (int(v) for v in box)
        # +1 makes the bottom-right edge inclusive.
        crops.append(feat[:, :, y0:y1 + 1, x0:x1 + 1])
    return crops
31
+
32
class Counting_with_SD_features(nn.Module):
    """Counting head over Stable-Diffusion features: ROI adapter only.

    The density regressor is intentionally not instantiated here; this
    variant exposes just the exemplar-box adapter.
    """

    def __init__(self, scale_factor):
        # `scale_factor` is accepted for constructor parity but unused.
        super(Counting_with_SD_features, self).__init__()
        self.adapter = adapter_roi()
37
+
38
class Counting_with_SD_features_loca(nn.Module):
    """Counting model: LOCA-style ROI adapter plus a density regressor."""

    def __init__(self, scale_factor):
        # `scale_factor` is kept for signature compatibility; not used.
        super(Counting_with_SD_features_loca, self).__init__()
        self.adapter = adapter_roi_loca()
        self.regressor = regressor_with_SD_features()
43
+
44
+
45
class Counting_with_SD_features_dino_vit_c3(nn.Module):
    """Segmentation variant: LOCA adapter + Cellpose-ViT-backed head."""

    def __init__(self, scale_factor, vit=None):
        # Both constructor arguments are unused; kept for a uniform ctor
        # signature across the model variants in this module.
        super(Counting_with_SD_features_dino_vit_c3, self).__init__()
        self.adapter = adapter_roi_loca()
        self.regressor = regressor_with_SD_features_seg_vit_c3()
50
+
51
class Counting_with_SD_features_track(nn.Module):
    """Tracking variant: LOCA adapter + tracking/association head."""

    def __init__(self, scale_factor, vit=None):
        # `scale_factor` and `vit` are unused; retained for interface parity.
        super(Counting_with_SD_features_track, self).__init__()
        self.adapter = adapter_roi_loca()
        self.regressor = regressor_with_SD_features_tra()
56
+
57
+
58
class adapter_roi(nn.Module):
    """Adapter turning exemplar-box ROI features into a 768-d embedding.

    ROI-aligns a 3x3 patch per exemplar box from a 256-channel feature
    map, averages over all boxes, and projects to 768 dims through a
    bottleneck MLP (fc1/fc2).
    """

    def __init__(self, pool_size=[3, 3]):
        super(adapter_roi, self).__init__()
        # NOTE(review): `pool_size` and `pool` are stored but never used in
        # forward(); the effective ROI size is the hard-coded output_size=3.
        self.pool_size = pool_size
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        # self.relu = nn.ReLU()
        # self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(256 * 3 * 3, 768)
        # **new
        # Bottleneck projection 768 -> 192 -> 768 (squeeze/excite style).
        self.fc1 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(768, 768 // 4, bias=False),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(768 // 4, 768, bias=False),
            # nn.ReLU()
        )
        self.initialize_weights()

    def forward(self, x, boxes):
        """x: bs x 256 x h x w feature map; boxes: bs x m x 4 boxes.

        Returns a 1 x 768 embedding averaged over all boxes.
        """
        num_of_boxes = boxes.shape[1]
        rois = []
        bs, _, h, w = x.shape
        # Prepend the per-row batch-index column required by torchvision's
        # roi_align when boxes are passed as a single tensor.
        boxes = torch.cat([
            torch.arange(
                bs, requires_grad=False
            ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
            boxes.flatten(0, 1),
        ], dim=1)
        rois = roi_align(
            x,
            boxes=boxes, output_size=3,
            spatial_scale=1.0 / 8, aligned=True
        )
        # Average the pooled patches over all boxes -> one exemplar token.
        rois = torch.mean(rois, dim=0, keepdim=True)
        x = self.conv1(rois)
        x = x.view(x.size(0), -1)
        x = self.fc(x)

        x = self.fc1(x)
        x = self.fc2(x)
        return x


    def initialize_weights(self):
        # Xavier init for conv/linear weights; biases zeroed.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
110
+
111
+
112
class adapter_roi_loca(nn.Module):
    """LOCA-style exemplar adapter: ROI features -> 768-d embedding.

    Unlike `adapter_roi`, this version first resizes the feature map to
    512x512, supports batched inputs, and has no bottleneck MLP after the
    linear projection.
    """

    def __init__(self, pool_size=[3, 3]):
        super(adapter_roi_loca, self).__init__()
        # NOTE(review): `pool_size` and `pool` are stored but unused; the
        # effective ROI size is the hard-coded output_size=3 in forward().
        self.pool_size = pool_size
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(256 * 3 * 3, 768)
        self.initialize_weights()
    def forward(self, x, boxes):
        """x: bs x 256 x h x w; boxes: bs x m x 4 exemplar boxes.

        Returns one 768-d embedding per image (boxes averaged per image).
        """
        num_of_boxes = boxes.shape[1]
        rois = []
        bs, _, h, w = x.shape
        # spatial_scale=1/8 below assumes a 512-px frame of reference, so
        # resize the feature map if it arrives at a different resolution.
        if h != 512 or w != 512:
            x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)
        if bs == 1:
            # Single image: pass boxes as one tensor with a leading
            # batch-index column, as torchvision's roi_align expects.
            boxes = torch.cat([
                torch.arange(
                    bs, requires_grad=False
                ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
                boxes.flatten(0, 1),
            ], dim=1)
            rois = roi_align(
                x,
                boxes=boxes, output_size=3,
                spatial_scale=1.0 / 8, aligned=True
            )
            rois = torch.mean(rois, dim=0, keepdim=True)
        else:
            # Batched: pass a per-image list of (m x 4) box tensors instead.
            boxes = torch.cat([
                boxes.flatten(0, 1),
            ], dim=1).split(num_of_boxes, dim=0)
            rois = roi_align(
                x,
                boxes=boxes, output_size=3,
                spatial_scale=1.0 / 8, aligned=True
            )
            rois = rois.split(num_of_boxes, dim=0)
            rois = torch.stack(rois, dim=0)
            # Average the m boxes belonging to each image.
            rois = torch.mean(rois, dim=1, keepdim=False)
        x = self.conv1(rois)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def forward_boxes(self, x, boxes):
        """Like forward() but keeps one embedding per box (no averaging).

        Only bs == 1 is supported.
        """
        num_of_boxes = boxes.shape[1]
        rois = []
        bs, _, h, w = x.shape
        if h != 512 or w != 512:
            x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)
        if bs == 1:
            boxes = torch.cat([
                torch.arange(
                    bs, requires_grad=False
                ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
                boxes.flatten(0, 1),
            ], dim=1)
            rois = roi_align(
                x,
                boxes=boxes, output_size=3,
                spatial_scale=1.0 / 8, aligned=True
            )
            # rois = torch.mean(rois, dim=0, keepdim=True)
        else:
            raise NotImplementedError
        x = self.conv1(rois)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def initialize_weights(self):
        # Xavier init for conv/linear weights; biases zeroed.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
188
+
189
+
190
+
191
+
192
class regressor1(nn.Module):
    """Small upsampling head: 4-ch map -> 1-ch non-negative map at 4x size."""

    def __init__(self):
        super(regressor1, self).__init__()
        self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
        self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
        self.leaky_relu = nn.LeakyReLU()
        self.relu = nn.ReLU()
        self.initialize_weights()

    def forward(self, x):
        # conv -> LeakyReLU -> 2x bilinear upsample, twice, then a final
        # projection clamped non-negative with ReLU.
        h = self.leaky_relu(self.conv1(x))
        h = self.upsampler(h)
        h = self.leaky_relu(self.conv2(h))
        h = self.upsampler(h)
        return self.relu(self.conv3(h))

    def initialize_weights(self):
        """Xavier-init all conv/linear weights; zero all biases."""
        for mod in self.modules():
            if isinstance(mod, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(mod.weight)
                if mod.bias is not None:
                    nn.init.constant_(mod.bias, 0)
223
+
224
+
225
class regressor1(nn.Module):
    """Upsampling density head: 4-ch input -> 1-ch non-negative map at 4x.

    NOTE(review): this redefinition shadows an otherwise-identical
    `regressor1` declared earlier in this module; one of the two copies
    should eventually be removed.
    """

    def __init__(self):
        super(regressor1, self).__init__()
        self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
        self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
        self.leaky_relu = nn.LeakyReLU()
        self.relu = nn.ReLU()
        # Bug fix: initialize_weights() was defined but never invoked in
        # this copy -- unlike the earlier regressor1 and every sibling
        # module -- so it silently kept PyTorch's default initialization.
        self.initialize_weights()

    def forward(self, x):
        """Decode x (N, 4, H, W) to a non-negative (N, 1, 4H, 4W) map."""
        x_ = self.conv1(x)
        x_ = self.leaky_relu(x_)
        x_ = self.upsampler(x_)
        x_ = self.conv2(x_)
        x_ = self.leaky_relu(x_)
        x_ = self.upsampler(x_)
        x_ = self.conv3(x_)
        x_ = self.relu(x_)
        out = x_
        return out

    def initialize_weights(self):
        """Xavier-init conv/linear weights; zero biases."""
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
252
+
253
+
254
class regressor_with_SD_features(nn.Module):
    """Density regressor: SD attention maps + UNet features -> density map.

    Fuses a [1, C, 64, 64] attention stack with the last UNet feature map
    (324 channels total after concatenation) and decodes to a
    single-channel non-negative density map at 8x resolution (512 x 512).
    """

    def __init__(self):
        super(regressor_with_SD_features, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(324, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64))
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.ReLU()
        )
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.initialize_weights()

    def forward(self, attn_stack, feature_list):
        """attn_stack: [1, C, 64, 64]; feature_list[-1]: [1, 324-C, 64, 64].

        Returns a [1, 1, 512, 512] density map, divided by 100 to keep
        early-training outputs in a small range.
        """
        attn_stack = self.norm(attn_stack)
        unet_feature = feature_list[-1]
        # Gate the UNet features with the mean attention response.
        attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        unet_feature = unet_feature * attn_stack_mean
        unet_feature = torch.cat([unet_feature, attn_stack], dim=1)  # [1, 324, 64, 64]
        x = self.layer1(unet_feature)
        x = self.layer2(x)
        x = self.layer3(x)
        x = self.layer4(x)
        out = self.conv(x)
        return out / 100

    def initialize_weights(self):
        # Xavier init with zero bias. Bug fix: ConvTranspose2d layers were
        # previously skipped (ConvTranspose2d is not an nn.Conv2d subclass)
        # and kept PyTorch's default init, unlike regressor_with_deconv in
        # this module which explicitly initializes them. Included here for
        # consistency.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
303
+
304
class regressor_with_SD_features_seg(nn.Module):
    """Segmentation decoder over fused SD attention and UNet features.

    Fuses a [1, C, 64, 64] attention stack with the last UNet feature map
    (324 channels after concatenation) and decodes to [1, 2, 512, 512]
    logits with no final activation.
    """

    def __init__(self):
        super(regressor_with_SD_features_seg, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(324, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64))
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 2, kernel_size=1),
        )
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.initialize_weights()

    def forward(self, attn_stack, feature_list):
        """Fuse attention with UNet features and decode to 2-ch logits."""
        normed_attn = self.norm(attn_stack)
        # Gate the UNet features with the mean attention response.
        fused = feature_list[-1] * normed_attn.mean(dim=1, keepdim=True)
        fused = torch.cat([fused, normed_attn], dim=1)  # [1, 324, 64, 64]
        h = self.layer1(fused)
        h = self.layer2(h)
        h = self.layer3(h)
        h = self.layer4(h)
        return self.conv(h)

    def initialize_weights(self):
        """Xavier-init conv/linear weights; zero their biases."""
        for mod in self.modules():
            if isinstance(mod, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(mod.weight)
                if mod.bias is not None:
                    nn.init.constant_(mod.bias, 0)
353
+
354
+
355
+ from models.enc_model.unet_parts import *
356
+
357
+
358
class regressor_with_SD_features_seg_vit_c3(nn.Module):
    """Segmentation head feeding SD guidance maps into a Cellpose model.

    Builds a 3-channel guidance map from two attention channels plus the
    mean UNet feature, then runs Cellpose on the raw image with that map
    passed as an extra `feat` input.
    """

    def __init__(self, n_channels=3, n_classes=2, bilinear=False):
        super(regressor_with_SD_features_seg_vit_c3, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        # Collapses [mean UNet feature, 2 attention maps] to 3 channels.
        self.inc_0 = nn.Conv2d(n_channels, 3, kernel_size=3, padding=1)
        # NOTE(review): gpu=True is hard-coded; confirm behavior on
        # CPU-only hosts.
        self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
        self.vit = self.vit_model.net

    def forward(self, img, attn_stack, feature_list):
        # Keep only attention channels 1 and 3 -- presumably the channels
        # tied to the relevant prompt tokens; TODO confirm against caller.
        attn_stack = attn_stack[:, [1,3], ...]
        attn_stack = self.norm(attn_stack)
        unet_feature = feature_list[-1]
        unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)

        x = torch.cat([unet_feature_mean, attn_stack], dim=1)  # 1 x 3 x 64 x 64

        if x.shape[-1] != 512:
            x = F.interpolate(x, size=(512, 512), mode="bilinear")
        x = self.inc_0(x)



        # Cellpose eval runs on CPU numpy arrays and returns an instance
        # label map; this round-trip detaches from the autograd graph.
        out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
        if out.dtype == np.uint16:
            # torch.from_numpy cannot consume uint16 arrays.
            out = out.astype(np.int16)
        out = torch.from_numpy(out).unsqueeze(0).to(x.device)
        return out

    def initialize_weights(self):
        # Xavier init for conv/linear weights; zero biases. Not called in
        # __init__ -- invoke explicitly if needed.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
395
+
396
class regressor_with_SD_features_tra(nn.Module):
    """Tracking head: Cellpose-based segmentation plus an association MLP.

    `forward_seg` segments a frame with Cellpose guided by SD features;
    `forward` embeds paired per-instance attention maps from consecutive
    frames into 320-d vectors for instance association.
    """

    def __init__(self, n_channels=2, n_classes=2, bilinear=False):
        super(regressor_with_SD_features_tra, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))

        # segmentation
        self.inc_0 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        # NOTE(review): gpu=True is hard-coded; confirm CPU-only behavior.
        self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
        self.vit = self.vit_model.net

        # association head: 2 stacked attention maps -> 1 channel -> MLP
        self.inc_1 = nn.Conv2d(n_channels, 1, kernel_size=3, padding=1)
        self.mlp = nn.Linear(64 * 64, 320)
        # self.vit = self.vit_model.net.float()

    def forward_seg(self, img, attn_stack, feature_list, mask, training=False):
        """Segment one frame. `mask` and `training` are currently unused."""
        # Keep attention channels 1 and 3 only -- TODO confirm which prompt
        # tokens these correspond to.
        attn_stack = attn_stack[:, [1,3], ...]
        attn_stack = self.norm(attn_stack)
        unet_feature = feature_list[-1]
        unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
        x = torch.cat([unet_feature_mean, attn_stack], dim=1)  # 1 x 3 x 64 x 64

        if x.shape[-1] != 512:
            x = F.interpolate(x, size=(512, 512), mode="bilinear")
        x = self.inc_0(x)
        feat = x

        # Cellpose eval runs on CPU numpy and returns an instance label map;
        # this detaches from the autograd graph.
        out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
        if out.dtype == np.uint16:
            # torch.from_numpy cannot consume uint16 arrays.
            out = out.astype(np.int16)
        out = torch.from_numpy(out).unsqueeze(0).to(x.device)
        return out, 0., feat

    def forward(self, attn_prev, feature_list_prev, attn_after, feature_list_after):
        """Embed paired per-instance attention maps for association.

        attn_prev / attn_after: [n_instances, 1, 64, 64] attention maps of
        the same instances in consecutive frames. Returns a
        [1, n_instances, 320] embedding. The feature_list arguments are
        currently unused.
        """
        assert attn_prev.shape == attn_after.shape, "attn_prev and attn_after must have the same shape"
        n_instances = attn_prev.shape[0]
        attn_prev = self.norm(attn_prev)  # [n_instances, 1, 64, 64]
        attn_after = self.norm(attn_after)

        x = torch.cat([attn_prev, attn_after], dim=1)  # n_instances, 2, 64, 64

        x = self.inc_1(x)
        x = x.view(1, n_instances, -1)  # flatten each instance's 64x64 map
        x = self.mlp(x)  # project to a 320-d embedding per instance

        return x  # [1, n_instances, 320]



    def initialize_weights(self):
        # Xavier init for conv/linear weights; zero biases. Not called in
        # __init__ -- invoke explicitly if needed.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
453
+
454
+
455
+
456
class regressor_with_SD_features_inst_seg_unet(nn.Module):
    """Instance-segmentation head: SD guidance maps + image through a UNet.

    Concatenates the attention-gated mean UNet feature, the attention
    stack, and the raw image, then runs a standard UNet (DoubleConv / Down
    / Up / OutConv from unet_parts) to predict `n_classes` output maps.
    """

    def __init__(self, n_channels=8, n_classes=3, bilinear=False):
        super(regressor_with_SD_features_inst_seg_unet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        # Reduces the [img + guidance] stack down to 3 channels for the stem.
        self.inc_0 = (DoubleConv(n_channels, 3))
        self.inc = (DoubleConv(3, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, img, attn_stack, feature_list):
        attn_stack = self.norm(attn_stack)
        unet_feature = feature_list[-1]
        unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
        attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        # Gate the pooled feature with the mean attention response.
        unet_feature_mean = unet_feature_mean * attn_stack_mean
        x = torch.cat([unet_feature_mean, attn_stack], dim=1)
        if x.shape[-1] != 512:
            x = F.interpolate(x, size=(512, 512), mode="bilinear")
        x = torch.cat([img, x], dim=1)  # e.g. [1, 8, 512, 512]
        x = self.inc_0(x)
        # Standard UNet encoder/decoder with skip connections.
        x1 = self.inc(x)
        x2 = self.down1(x1)
        x3 = self.down2(x2)
        x4 = self.down3(x3)
        x5 = self.down4(x4)
        x = self.up1(x5, x4)
        x = self.up2(x, x3)
        x = self.up3(x, x2)
        x = self.up4(x, x1)
        out = self.outc(x)
        return out

    def initialize_weights(self):
        # Xavier init for conv/linear weights; zero biases. Not called in
        # __init__ -- invoke explicitly if needed.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
505
+
506
+
507
class regressor_with_SD_features_self(nn.Module):
    """Projects stacked self-attention maps (4096 ch) down to 256 channels.

    Only `self.layer` is used in forward(); layer2-4 and `conv` form a
    decoder that is currently dead code (see the commented-out path below).
    """

    def __init__(self):
        super(regressor_with_SD_features_self, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(4096, 1024, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64)),
            nn.Conv2d(1024, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64)),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.ReLU()
        )
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.initialize_weights()

    def forward(self, self_attn):
        # Move the attention-channel axis first; assumes the input arrives
        # as [64, 64, 4096] -- TODO confirm against caller.
        self_attn = self_attn.permute(2, 0, 1)
        self_attn = self.layer(self_attn)
        return self_attn
        # Dead code below: the original full decoder path, kept for reference.
        # attn_stack = self.norm(attn_stack)
        # unet_feature = feature_list[-1]
        # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        # unet_feature = unet_feature * attn_stack_mean
        # unet_feature = torch.cat([unet_feature, attn_stack], dim=1)  # [1, 324, 64, 64]
        # x = self.layer(unet_feature)
        # x = self.layer2(x)
        # x = self.layer3(x)
        # x = self.layer4(x)
        # out = self.conv(x)
        # return out / 100

    def initialize_weights(self):
        # Xavier init for conv/linear weights; zero biases.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
562
+
563
+
564
class regressor_with_SD_features_latent(nn.Module):
    """Projects a 4-channel SD latent map to 256 normalized channels.

    Only `layer` participates in forward(); layer2-4 and `conv` build a
    full decoder that is constructed (and initialized) but currently
    unused.
    """

    def __init__(self):
        super(regressor_with_SD_features_latent, self).__init__()
        self.layer = nn.Sequential(
            nn.Conv2d(4, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64))
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.ReLU()
        )
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.initialize_weights()

    def forward(self, self_attn):
        # Single projection step; the decoder layers above are not applied.
        return self.layer(self_attn)

    def initialize_weights(self):
        """Xavier-init conv/linear weights; zero biases."""
        for mod in self.modules():
            if isinstance(mod, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(mod.weight)
                if mod.bias is not None:
                    nn.init.constant_(mod.bias, 0)
616
+
617
+
618
+
619
+
620
+
621
class regressor_with_deconv(nn.Module):
    """4x upsampling head using learned transposed convolutions.

    Same contract as `regressor1` (4-ch in, 1 non-negative ch out at 4x
    resolution) but upsamples with ConvTranspose2d instead of bilinear
    interpolation.
    """

    def __init__(self):
        super(regressor_with_deconv, self).__init__()
        self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
        self.deconv1 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU()
        self.relu = nn.ReLU()
        self.initialize_weights()

    def forward(self, x):
        # Two conv -> LeakyReLU -> learned-2x-upsample stages, then a final
        # projection clamped non-negative with ReLU.
        h = self.leaky_relu(self.conv1(x))
        h = self.deconv1(h)
        h = self.leaky_relu(self.conv2(h))
        h = self.deconv2(h)
        return self.relu(self.conv3(h))

    def initialize_weights(self):
        """Xavier-init conv/deconv/linear weights; zero biases."""
        for mod in self.modules():
            if isinstance(mod, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.xavier_normal_(mod.weight)
                if mod.bias is not None:
                    nn.init.constant_(mod.bias, 0)
651
+
652
+
653
+
models/seg_post_model/cellpose/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .version import version, version_str
models/seg_post_model/cellpose/__main__.py ADDED
@@ -0,0 +1,272 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import os, time
5
+ import numpy as np
6
+ from tqdm import tqdm
7
+ from cellpose import utils, models, io, train
8
+ from .version import version_str
9
+ from cellpose.cli import get_arg_parser
10
+
11
+ try:
12
+ from cellpose.gui import gui3d, gui
13
+ GUI_ENABLED = True
14
+ except ImportError as err:
15
+ GUI_ERROR = err
16
+ GUI_ENABLED = False
17
+ GUI_IMPORT = True
18
+ except Exception as err:
19
+ GUI_ENABLED = False
20
+ GUI_ERROR = err
21
+ GUI_IMPORT = False
22
+ raise
23
+
24
+ import logging
25
+
26
+
27
def main():
    """ Run cellpose from command line

    Dispatches to the GUI when no images are given, otherwise to the
    evaluation or training helpers below based on --train.
    """

    args = get_arg_parser().parse_args()  # this has to be in a separate file for autodoc to work

    if args.version:
        print(version_str)
        return

    ######## if no image arguments are provided, run GUI or add model and exit ########
    if len(args.dir) == 0 and len(args.image_path) == 0:
        if args.add_model:
            io.add_model(args.add_model)
            return
        else:
            if not GUI_ENABLED:
                print("GUI ERROR: %s" % GUI_ERROR)
                if GUI_IMPORT:
                    print(
                        "GUI FAILED: GUI dependencies may not be installed, to install, run"
                    )
                    print(" pip install 'cellpose[gui]'")
            else:
                if args.Zstack:
                    gui3d.run()
                else:
                    gui.run()
            return

    ############################## run cellpose on images ##############################
    if args.verbose:
        from .io import logger_setup
        logger, log_file = logger_setup()
    else:
        print(
            ">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
        print("No --verbose => no progress or info printed")
        logger = logging.getLogger(__name__)

    # find images
    if len(args.img_filter) > 0:
        image_filter = args.img_filter
    else:
        image_filter = None

    device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu,
                                       device=args.gpu_device)

    # Falsy/stringly-false pretrained-model values all fall back to cpsam.
    if args.pretrained_model is None or args.pretrained_model == "None" or args.pretrained_model == "False" or args.pretrained_model == "0":
        pretrained_model = "cpsam"
        logger.warning("training from scratch is disabled, using 'cpsam' model")
    else:
        pretrained_model = args.pretrained_model

    # Warn users about old arguments from CP3:
    if args.pretrained_model_ortho:
        logger.warning(
            "the '--pretrained_model_ortho' flag is deprecated in v4.0.1+ and no longer used")
    if args.train_size:
        logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")
    if args.chan or args.chan2:
        logger.warning('--chan and --chan2 are deprecated, all channels are used by default')
    if args.all_channels:
        logger.warning("the '--all_channels' flag is deprecated in v4.0.1+ and no longer used")
    if args.restore_type:
        logger.warning("the '--restore_type' flag is deprecated in v4.0.1+ and no longer used")
    if args.transformer:
        logger.warning("the '--tranformer' flag is deprecated in v4.0.1+ and no longer used")
    if args.invert:
        logger.warning("the '--invert' flag is deprecated in v4.0.1+ and no longer used")
    if args.chan2_restore:
        logger.warning("the '--chan2_restore' flag is deprecated in v4.0.1+ and no longer used")
    if args.diam_mean:
        logger.warning("the '--diam_mean' flag is deprecated in v4.0.1+ and no longer used")
    if args.train_size:
        logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")

    # Normalization is either percentile-based or a simple on/off flag.
    if args.norm_percentile is not None:
        value1, value2 = args.norm_percentile
        normalize = {'percentile': (float(value1), float(value2))}
    else:
        normalize = (not args.no_norm)

    if args.save_each:
        if not args.save_every:
            raise ValueError("ERROR: --save_each requires --save_every")

    if len(args.image_path) > 0 and args.train:
        raise ValueError("ERROR: cannot train model with single image input")

    ## Run evaluation on images
    if not args.train:
        _evaluate_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)

    ## Train a model ##
    else:
        _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
127
+
128
def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize):
    """Train a CellposeModel from CLI args and return the trained model.

    Data comes either from an .npy file-list dict (args.file_list) or from
    image directories (args.dir / args.test_dir).
    """
    test_dir = None if len(args.test_dir) == 0 else args.test_dir
    images, labels, image_names, train_probs = None, None, None, None
    test_images, test_labels, image_names_test, test_probs = None, None, None, None
    compute_flows = False
    if len(args.file_list) > 0:
        if os.path.exists(args.file_list):
            dat = np.load(args.file_list, allow_pickle=True).item()
            image_names = dat["train_files"]
            image_names_test = dat.get("test_files", None)
            train_probs = dat.get("train_probs", None)
            test_probs = dat.get("test_probs", None)
            compute_flows = dat.get("compute_flows", False)
            load_files = False
        else:
            # NOTE(review): execution continues past this message with
            # `load_files` unassigned, which raises UnboundLocalError at the
            # train_seg call below -- consider raising here instead.
            logger.critical(f"ERROR: {args.file_list} does not exist")
    else:
        output = io.load_train_test_data(args.dir, test_dir, image_filter,
                                         args.mask_filter,
                                         args.look_one_level_down)
        images, labels, image_names, test_images, test_labels, image_names_test = output
        load_files = True

    # initialize model
    model = models.CellposeModel(device=device, pretrained_model=pretrained_model)

    # train segmentation model
    cpmodel_path = train.train_seg(
        model.net, images, labels, train_files=image_names,
        test_data=test_images, test_labels=test_labels,
        test_files=image_names_test, train_probs=train_probs,
        test_probs=test_probs, compute_flows=compute_flows,
        load_files=load_files, normalize=normalize,
        channel_axis=args.channel_axis,
        learning_rate=args.learning_rate, weight_decay=args.weight_decay,
        SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.train_batch_size,
        min_train_masks=args.min_train_masks,
        nimg_per_epoch=args.nimg_per_epoch,
        nimg_test_per_epoch=args.nimg_test_per_epoch,
        save_path=os.path.realpath(args.dir),
        save_every=args.save_every,
        save_each=args.save_each,
        model_name=args.model_name_out)[0]
    model.pretrained_model = cpmodel_path
    logger.info(">>>> model trained and saved to %s" % cpmodel_path)
    return model
174
+
175
+
176
def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, normalize):
    """Run Cellpose inference over CLI-selected images and save outputs."""
    # Check with user if they REALLY mean to run without saving anything.
    # NOTE(review): `saving_something` is only assigned under this guard;
    # the caller only invokes this function when `not args.train`, so the
    # later reference is safe -- but the guard itself is redundant.
    if not args.train:
        saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt

    tic = time.time()
    # Collect image paths: a whole directory, or a single --image_path.
    if len(args.dir) > 0:
        image_names = io.get_image_files(
            args.dir, args.mask_filter, imf=imf,
            look_one_level_down=args.look_one_level_down)
    else:
        if os.path.exists(args.image_path):
            image_names = [args.image_path]
        else:
            raise ValueError(f"ERROR: no file found at {args.image_path}")
    nimg = len(image_names)

    if args.savedir:
        if not os.path.exists(args.savedir):
            raise FileExistsError(f"--savedir {args.savedir} does not exist")

    logger.info(
        ">>>> running cellpose on %d images using all channels" % nimg)

    # handle built-in model exceptions
    model = models.CellposeModel(device=device, pretrained_model=pretrained_model,)

    tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)

    channel_axis = args.channel_axis
    z_axis = args.z_axis

    for image_name in tqdm(image_names, file=tqdm_out):
        if args.do_3D or args.stitch_threshold > 0.:
            logger.info('loading image as 3D zstack')
            image = io.imread_3D(image_name)
            # NOTE(review): these defaults persist across loop iterations
            # once set, since channel_axis/z_axis are reassigned here.
            if channel_axis is None:
                channel_axis = 3
            if z_axis is None:
                z_axis = 0

        else:
            image = io.imread_2D(image_name)
        out = model.eval(
            image,
            diameter=args.diameter,
            do_3D=args.do_3D,
            augment=args.augment,
            flow_threshold=args.flow_threshold,
            cellprob_threshold=args.cellprob_threshold,
            stitch_threshold=args.stitch_threshold,
            min_size=args.min_size,
            batch_size=args.batch_size,
            bsize=args.bsize,
            resample=not args.no_resample,
            normalize=normalize,
            channel_axis=channel_axis,
            z_axis=z_axis,
            anisotropy=args.anisotropy,
            niter=args.niter,
            flow3D_smooth=args.flow3D_smooth)
        masks, flows = out[:2]

        if args.exclude_on_edges:
            masks = utils.remove_edge_masks(masks)
        if not args.no_npy:
            io.masks_flows_to_seg(image, masks, flows, image_name,
                                  imgs_restore=None,
                                  restore_type=None,
                                  ratio=1.)
        if saving_something:
            suffix = "_cp_masks"
            if args.output_name is not None:
                # (1) If `savedir` is not defined, then must have a non-zero `suffix`
                if args.savedir is None and len(args.output_name) > 0:
                    suffix = args.output_name
                elif args.savedir is not None and not os.path.samefile(args.savedir, args.dir):
                    # (2) If `savedir` is defined, and different from `dir` then
                    # takes the value passed as a param. (which can be empty string)
                    suffix = args.output_name

            io.save_masks(image, masks, flows, image_name,
                          suffix=suffix, png=args.save_png,
                          tif=args.save_tif, save_flows=args.save_flows,
                          save_outlines=args.save_outlines,
                          dir_above=args.dir_above, savedir=args.savedir,
                          save_txt=args.save_txt, in_folders=args.in_folders,
                          save_mpl=args.save_mpl)
        if args.save_rois:
            io.save_rois(masks, image_name)
    logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))

    return model
269
+
270
+
271
+ if __name__ == "__main__":
272
+ main()
models/seg_post_model/cellpose/cli.py ADDED
@@ -0,0 +1,240 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden.
3
+ """
4
+
5
+ import argparse
6
+
7
+
8
def get_arg_parser():
    """Parse command line arguments for the cellpose main function.

    Note: this function has to be in a separate file to allow autodoc to work for CLI.
    The autodoc_mock_imports in conf.py does not work for sphinx-argparse sometimes,
    see https://github.com/ashb/sphinx-argparse/issues/9#issue-1097057823

    Returns:
        argparse.ArgumentParser: parser with misc, hardware, input, model,
        algorithm, output, and training argument groups.
    """
    parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters")

    # misc settings
    parser.add_argument("--version", action="store_true",
                        help="show cellpose version info")
    parser.add_argument(
        "--verbose", action="store_true",
        help="show information about running and settings and save to log")
    parser.add_argument("--Zstack", action="store_true", help="run GUI in 3D mode")

    # settings for CPU vs GPU
    hardware_args = parser.add_argument_group("Hardware Arguments")
    hardware_args.add_argument("--use_gpu", action="store_true",
                               help="use gpu if torch with cuda installed")
    hardware_args.add_argument(
        "--gpu_device", required=False, default="0", type=str,
        help="which gpu device to use, use an integer for torch, or mps for M1")

    # settings for locating and formatting images
    input_img_args = parser.add_argument_group("Input Image Arguments")
    input_img_args.add_argument("--dir", default=[], type=str,
                                help="folder containing data to run or train on.")
    input_img_args.add_argument(
        "--image_path", default=[], type=str, help=
        "if given and --dir not given, run on single image instead of folder (cannot train with this option)"
    )
    input_img_args.add_argument(
        "--look_one_level_down", action="store_true",
        help="run processing on all subdirectories of current folder")
    input_img_args.add_argument("--img_filter", default=[], type=str,
                                help="end string for images to run on")
    input_img_args.add_argument(
        "--channel_axis", default=None, type=int,
        help="axis of image which corresponds to image channels")
    input_img_args.add_argument("--z_axis", default=None, type=int,
                                help="axis of image which corresponds to Z dimension")

    # TODO: remove deprecated in future version
    input_img_args.add_argument(
        "--chan", default=0, type=int, help=
        "Deprecated in v4.0.1+, not used. ")
    input_img_args.add_argument(
        "--chan2", default=0, type=int, help=
        'Deprecated in v4.0.1+, not used. ')
    input_img_args.add_argument("--invert", action="store_true", help=
                                'Deprecated in v4.0.1+, not used. ')
    input_img_args.add_argument(
        "--all_channels", action="store_true", help=
        'Deprecated in v4.0.1+, not used. ')

    # model settings
    model_args = parser.add_argument_group("Model Arguments")
    model_args.add_argument("--pretrained_model", required=False, default="cpsam",
                            type=str,
                            help="model to use for running or starting training")
    model_args.add_argument(
        "--add_model", required=False, default=None, type=str,
        help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
    model_args.add_argument("--pretrained_model_ortho", required=False, default=None,
                            type=str,
                            help="Deprecated in v4.0.1+, not used. ")

    # TODO: remove deprecated in future version
    model_args.add_argument("--restore_type", required=False, default=None, type=str, help=
                            'Deprecated in v4.0.1+, not used. ')
    model_args.add_argument("--chan2_restore", action="store_true", help=
                            'Deprecated in v4.0.1+, not used. ')
    model_args.add_argument(
        "--transformer", action="store_true", help=
        "use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)")

    # algorithm settings
    algorithm_args = parser.add_argument_group("Algorithm Arguments")
    algorithm_args.add_argument("--no_norm", action="store_true",
                                help="do not normalize images (normalize=False)")
    # NOTE(review): no type= given, so the two values are parsed as strings;
    # downstream code presumably converts to float — confirm before adding type=float.
    algorithm_args.add_argument(
        '--norm_percentile',
        nargs=2,  # Require exactly two values
        metavar=('VALUE1', 'VALUE2'),
        help="Provide two float values to set norm_percentile (e.g., --norm_percentile 1 99)"
    )
    # fixed: help string was missing the closing parenthesis
    algorithm_args.add_argument(
        "--do_3D", action="store_true",
        help="process images as 3D stacks of images (nplanes x nchan x Ly x Lx)")
    algorithm_args.add_argument(
        "--diameter", required=False, default=None, type=float, help=
        "use to resize cells to the training diameter (30 pixels)"
    )
    algorithm_args.add_argument(
        "--stitch_threshold", required=False, default=0.0, type=float,
        help="compute masks in 2D then stitch together masks with IoU>0.9 across planes"
    )
    algorithm_args.add_argument(
        "--min_size", required=False, default=15, type=int,
        help="minimum number of pixels per mask, can turn off with -1")
    algorithm_args.add_argument(
        "--flow3D_smooth", required=False, default=0, type=float,
        help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing")
    algorithm_args.add_argument(
        "--flow_threshold", default=0.4, type=float, help=
        "flow error threshold, 0 turns off this optional QC step. Default: %(default)s")
    algorithm_args.add_argument(
        "--cellprob_threshold", default=0, type=float,
        help="cellprob threshold, default is 0, decrease to find more and larger masks")
    algorithm_args.add_argument(
        "--niter", default=0, type=int, help=
        "niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs"
    )
    algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
                                help="anisotropy of volume in 3D")
    algorithm_args.add_argument("--exclude_on_edges", action="store_true",
                                help="discard masks which touch edges of image")
    algorithm_args.add_argument(
        "--augment", action="store_true",
        help="tiles image with overlapping tiles and flips overlapped regions to augment"
    )
    algorithm_args.add_argument("--batch_size", default=8, type=int,
                                help="inference batch size. Default: %(default)s")

    # TODO: remove deprecated in future version
    algorithm_args.add_argument(
        "--no_resample", action="store_true",
        help="disables flows/cellprob resampling to original image size before computing masks. Using this flag will make more masks more jagged with larger diameter settings.")
    algorithm_args.add_argument(
        "--no_interp", action="store_true",
        help="do not interpolate when running dynamics (was default)")

    # output settings
    output_args = parser.add_argument_group("Output Arguments")
    output_args.add_argument(
        "--save_png", action="store_true",
        help="save masks as png")
    output_args.add_argument(
        "--save_tif", action="store_true",
        help="save masks as tif")
    output_args.add_argument(
        "--output_name", default=None, type=str,
        help="suffix for saved masks, default is _cp_masks, can be empty if `savedir` used and different of `dir`")
    output_args.add_argument("--no_npy", action="store_true",
                             help="suppress saving of npy")
    output_args.add_argument(
        "--savedir", default=None, type=str, help=
        "folder to which segmentation results will be saved (defaults to input image directory)"
    )
    output_args.add_argument(
        "--dir_above", action="store_true", help=
        "save output folders adjacent to image folder instead of inside it (off by default)"
    )
    output_args.add_argument("--in_folders", action="store_true",
                             help="flag to save output in folders (off by default)")
    output_args.add_argument(
        "--save_flows", action="store_true", help=
        "whether or not to save RGB images of flows when masks are saved (disabled by default)"
    )
    output_args.add_argument(
        "--save_outlines", action="store_true", help=
        "whether or not to save RGB outline images when masks are saved (disabled by default)"
    )
    output_args.add_argument(
        "--save_rois", action="store_true",
        help="whether or not to save ImageJ compatible ROI archive (disabled by default)"
    )
    output_args.add_argument(
        "--save_txt", action="store_true",
        help="flag to enable txt outlines for ImageJ (disabled by default)")
    output_args.add_argument(
        "--save_mpl", action="store_true",
        help="save a figure of image/mask/flows using matplotlib (disabled by default). "
        "This is slow, especially with large images.")

    # training settings
    training_args = parser.add_argument_group("Training Arguments")
    training_args.add_argument("--train", action="store_true",
                               help="train network using images in dir")
    training_args.add_argument("--test_dir", default=[], type=str,
                               help="folder containing test data (optional)")
    training_args.add_argument(
        "--file_list", default=[], type=str, help=
        "path to list of files for training and testing and probabilities for each image (optional)"
    )
    training_args.add_argument(
        "--mask_filter", default="_masks", type=str, help=
        "end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s"
    )
    training_args.add_argument("--learning_rate", default=1e-5, type=float,
                               help="learning rate. Default: %(default)s")
    training_args.add_argument("--weight_decay", default=0.1, type=float,
                               help="weight decay. Default: %(default)s")
    training_args.add_argument("--n_epochs", default=100, type=int,
                               help="number of epochs. Default: %(default)s")
    training_args.add_argument("--train_batch_size", default=1, type=int,
                               help="training batch size. Default: %(default)s")
    training_args.add_argument("--bsize", default=256, type=int,
                               help="block size for tiles. Default: %(default)s")
    training_args.add_argument(
        "--nimg_per_epoch", default=None, type=int,
        help="number of train images per epoch. Default is to use all train images.")
    training_args.add_argument(
        "--nimg_test_per_epoch", default=None, type=int,
        help="number of test images per epoch. Default is to use all test images.")
    training_args.add_argument(
        "--min_train_masks", default=5, type=int, help=
        "minimum number of masks a training image must have to be used. Default: %(default)s"
    )
    training_args.add_argument("--SGD", default=0, type=int,
                               help="Deprecated in v4.0.1+, not used - AdamW used instead. ")
    training_args.add_argument(
        "--save_every", default=100, type=int,
        help="number of epochs to skip between saves. Default: %(default)s")
    # fixed: typo "wether" -> "whether" in help text
    training_args.add_argument(
        "--save_each", action="store_true",
        help="whether or not to save each epoch. Must also use --save_every. (default: False)")
    training_args.add_argument(
        "--model_name_out", default=None, type=str,
        help="Name of model to save as, defaults to name describing model architecture. "
        "Model is saved in the folder specified by --dir in models subfolder.")

    # TODO: remove deprecated in future version
    training_args.add_argument(
        "--diam_mean", default=30., type=float, help=
        'Deprecated in v4.0.1+, not used. ')
    training_args.add_argument("--train_size", action="store_true", help=
                               'Deprecated in v4.0.1+, not used. ')

    return parser
models/seg_post_model/cellpose/core.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import logging
5
+ import numpy as np
6
+ from tqdm import trange
7
+ from . import transforms, utils
8
+
9
+ import torch
10
+
11
+ TORCH_ENABLED = True
12
+
13
+ core_logger = logging.getLogger(__name__)
14
+ tqdm_out = utils.TqdmToLogger(core_logger, level=logging.INFO)
15
+
16
+
17
def use_gpu(gpu_number=0, use_torch=True):
    """Return True if a GPU (CUDA or MPS) is available through PyTorch.

    Args:
        gpu_number (int): Index of the GPU to probe. Defaults to 0.
        use_torch (bool): Must be True; cellpose is torch-only. Defaults to True.

    Returns:
        bool: True if a working GPU was found, False otherwise.

    Raises:
        ValueError: If use_torch is False (cellpose only runs with PyTorch now).
    """
    # guard clause instead of if/else: torch is the only supported backend
    if not use_torch:
        raise ValueError("cellpose only runs with PyTorch now")
    return _use_gpu_torch(gpu_number)
35
+
36
+
37
def _use_gpu_torch(gpu_number=0):
    """
    Checks if CUDA or MPS is available and working with PyTorch.

    Probes each backend by allocating a tiny tensor on it; any failure is
    treated as "not available" and the next backend is tried.

    Args:
        gpu_number (int): The GPU device number to use (default is 0).

    Returns:
        bool: True if CUDA or MPS is available and working, False otherwise.
    """
    try:
        device = torch.device("cuda:" + str(gpu_number))
        _ = torch.zeros((1, 1)).to(device)
        core_logger.info("** TORCH CUDA version installed and working. **")
        return True
    except Exception:
        # fixed: was a bare `except:` which also swallowed
        # KeyboardInterrupt/SystemExit; fall through to the MPS probe
        pass
    try:
        device = torch.device('mps:' + str(gpu_number))
        _ = torch.zeros((1, 1)).to(device)
        core_logger.info('** TORCH MPS version installed and working. **')
        return True
    except Exception:
        # fixed: log message had a double negative ("Neither ... not installed")
        core_logger.info('Neither TORCH CUDA nor MPS version installed/working.')
        return False
62
+
63
+
64
def assign_device(use_torch=True, gpu=False, device=0):
    """
    Assigns the device (CPU or GPU or mps) to be used for computation.

    Args:
        use_torch (bool, optional): Whether to use torch for GPU detection. Defaults to True.
        gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
        device (int or str, optional): The device index or name to be used. Defaults to 0.

    Returns:
        torch.device, bool (True if GPU is used, False otherwise)
    """
    if isinstance(device, str):
        # "mps" is only honored when a GPU was requested and MPS is available;
        # any other string is interpreted as a CUDA device index.
        if device != "mps" or not (gpu and torch.backends.mps.is_available()):
            device = int(device)
    # fixed: `cpu` could previously be referenced before assignment (NameError)
    # when the GPU branch ran but neither backend check assigned it
    cpu = True
    if gpu and use_gpu(use_torch=True):
        try:
            if torch.cuda.is_available():
                device = torch.device(f'cuda:{device}')
                core_logger.info(">>>> using GPU (CUDA)")
                gpu = True
                cpu = False
        except Exception:
            pass
        if cpu:
            # fixed: the MPS probe previously ran unconditionally and could
            # silently override an already-selected CUDA device
            try:
                if torch.backends.mps.is_available():
                    device = torch.device('mps')
                    core_logger.info(">>>> using GPU (MPS)")
                    gpu = True
                    cpu = False
            except Exception:
                pass
    if cpu:
        device = torch.device("cpu")
        core_logger.info(">>>> using CPU")
        gpu = False
    return device, gpu
110
+
111
+
112
def _to_device(x, device, dtype=torch.float32):
    """
    Move *x* onto *device* as a torch tensor.

    A numpy array is converted to a tensor of the given dtype on the target
    device; an input that is already a tensor is returned unchanged (its
    device and dtype are NOT modified).

    Args:
        x (torch.Tensor or numpy.ndarray): The input tensor or numpy array.
        device (torch.device): The target device.
        dtype (torch.dtype): dtype used when converting from numpy.

    Returns:
        torch.Tensor: The tensor on the specified device (or the original tensor).
    """
    if isinstance(x, torch.Tensor):
        return x
    return torch.from_numpy(x).to(device, dtype=dtype)
128
+
129
+
130
def _from_device(X):
    """
    Copy a PyTorch tensor back to the CPU as a float32 NumPy array.

    Args:
        X (torch.Tensor): The input PyTorch tensor.

    Returns:
        numpy.ndarray: The converted float32 NumPy array.
    """
    # cast to float32 before .numpy() so the conversion always works
    # (e.g. for half/bfloat16 tensors)
    return X.detach().cpu().to(torch.float32).numpy()
143
+
144
+
145
def _forward(net, x, feat=None):
    """Convert images to torch tensors, run the network, and return numpy arrays.

    Args:
        net (torch.nn.Module): The network model (must expose .device and .dtype).
        x (numpy.ndarray): The input images.
        feat: Optional extra feature input passed through to the net.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: The output predictions
        (flows and cellprob) and style features.
    """
    inputs = _to_device(x, device=net.device, dtype=net.dtype)
    features = None if feat is None else _to_device(feat, device=net.device, dtype=net.dtype)
    net.eval()
    with torch.no_grad():
        # only the first two outputs of the net are used
        y, style = net(inputs, feat=features)[:2]
    del inputs
    return _from_device(y), _from_device(style)
165
+
166
+
167
def run_net(net, imgi, feat=None, batch_size=8, augment=False, tile_overlap=0.1, bsize=224,
            rsz=None):
    """
    Run network on stack of images.

    (faster if augment is False)

    Args:
        net (class): cellpose network (model.net)
        imgi (np.ndarray): The input image or stack of images of size [Lz x Ly x Lx x nchan].
        feat (np.ndarray, optional): extra per-pixel features tiled alongside the image
            and forwarded to the net. Defaults to None.
            NOTE(review): allocated with `nchan` channels like the image — confirm feat
            actually has the same channel count as imgi.
        batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
        rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
        augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
        tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
        bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
        y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
        style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
    """
    # run network
    Lz, Ly0, Lx0, nchan = imgi.shape
    if rsz is not None:
        # scalar resize factor applies to both spatial dims
        if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
            rsz = [rsz, rsz]
        Lyr, Lxr = int(Ly0 * rsz[0]), int(Lx0 * rsz[1])
    else:
        Lyr, Lxr = Ly0, Lx0  # 512, 512

    ly, lx = bsize, bsize  # 256, 256
    # pad so the (possibly resized) image is at least one tile in each dim
    ypad1, ypad2, xpad1, xpad2 = transforms.get_pad_yx(Lyr, Lxr, min_size=(bsize, bsize))  # 8
    Ly, Lx = Lyr + ypad1 + ypad2, Lxr + xpad1 + xpad2  # 528, 528
    pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])

    # number of tiles per axis; augment mode uses denser tiling
    if augment:
        ny = max(2, int(np.ceil(2. * Ly / bsize)))
        nx = max(2, int(np.ceil(2. * Lx / bsize)))
    else:
        ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))  # 3
        nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))  # 3


    # run multiple slices at the same time
    ntiles = ny * nx
    nimgs = max(1, batch_size // ntiles)  # number of imgs to run in the same batch, 1
    niter = int(np.ceil(Lz / nimgs))  # 1
    # progress bar only for multi-plane / long jobs
    ziterator = (trange(niter, file=tqdm_out, mininterval=30)
                 if niter > 10 or Lz > 1 else range(niter))
    for k in ziterator:
        inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs))
        # buffer holding all tiles for this group of z-planes
        IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")  # 9, 3, 256, 256
        if feat is not None:
            FEATa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")  # 9, 256
        else:
            FEATa = None
        for i, b in enumerate(inds):
            # pad image for net so Ly and Lx are divisible by 4
            imgb = transforms.resize_image(imgi[b], rsz=rsz) if rsz is not None else imgi[b].copy()
            imgb = np.pad(imgb.transpose(2, 0, 1), pads, mode="constant")  # 3, 528, 528

            IMG, ysub, xsub, Lyt, Lxt = transforms.make_tiles(
                imgb, bsize=bsize, augment=augment,
                tile_overlap=tile_overlap)  # IMG: 3, 3, 3, 256, 256
            IMGa[i * ntiles : (i+1) * ntiles] = np.reshape(IMG,
                                                           (ny * nx, nchan, ly, lx))
            if feat is not None:
                # tile the extra features exactly the same way as the image
                featb = transforms.resize_image(feat[b], rsz=rsz) if rsz is not None else feat[b].copy()
                featb = np.pad(featb.transpose(2, 0, 1), pads, mode="constant")
                FEAT, ysub, xsub, Lyt, Lxt = transforms.make_tiles(
                    featb, bsize=bsize, augment=augment,
                    tile_overlap=tile_overlap)
                FEATa[i * ntiles : (i+1) * ntiles] = np.reshape(FEAT,
                                                                (ny * nx, nchan, ly, lx))

        # run network
        for j in range(0, IMGa.shape[0], batch_size):
            bslc = slice(j, min(j + batch_size, IMGa.shape[0]))
            ya0, stylea0 = _forward(net, IMGa[bslc], feat=FEATa[bslc] if FEATa is not None else None)
            if j == 0:
                # output buffers allocated lazily once the number of output
                # channels (nout) is known from the first forward pass
                nout = ya0.shape[1]
                ya = np.zeros((IMGa.shape[0], nout, ly, lx), "float32")
                stylea = np.zeros((IMGa.shape[0], 256), "float32")
            ya[bslc] = ya0
            stylea[bslc] = stylea0

        # average tiles
        for i, b in enumerate(inds):
            if i == 0 and k == 0:
                yf = np.zeros((Lz, nout, Ly, Lx), "float32")
                styles = np.zeros((Lz, 256), "float32")
            y = ya[i * ntiles : (i + 1) * ntiles]
            if augment:
                # undo the flips applied to overlapped regions before averaging
                y = np.reshape(y, (ny, nx, 3, ly, lx))
                y = transforms.unaugment_tiles(y)
                y = np.reshape(y, (-1, 3, ly, lx))
            yfi = transforms.average_tiles(y, ysub, xsub, Lyt, Lxt)
            yf[b] = yfi[:, :imgb.shape[-2], :imgb.shape[-1]]
            # L2-normalize the per-plane style vector (summed over tiles)
            stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0)
            stylei /= (stylei**2).sum()**0.5
            styles[b] = stylei
    # slices from padding
    yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2]
    yf = yf.transpose(0, 2, 3, 1)
    return yf, np.array(styles)
272
+
273
+
274
def run_3D(net, imgs, batch_size=8, augment=False,
           tile_overlap=0.1, bsize=224, net_ortho=None,
           progress=None):
    """
    Run network on image z-stack, accumulating flows from the three
    orthogonal orientations (YX, ZY, ZX).

    (faster if augment is False)

    Args:
        net (class): cellpose network (model.net).
        imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan].
        batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
        augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
        tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
        bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
        net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes.
            Defaults to None. NOTE(review): accepted but never used in this body —
            all three orientations run through `net`.
        progress (QProgressBar, optional): pyqt progress bar. Defaults to None.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: outputs of network yf and style.
        yf has size [Lz x Ly x Lx x 4]: yf[...,0] is Z flow; yf[...,1] is Y flow;
        yf[...,2] is X flow; yf[...,3] is cell probability (each summed over the
        orientations that predict it).
        style is the style array returned by the LAST orientation pass only (ZX).
    """
    sstr = ["YX", "ZY", "ZX"]
    # pm[p]: axis permutation putting orientation p's plane axes last (plus channels);
    # ipm[p]: inverse permutation used to map results back to (Z, Y, X) order
    pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]
    ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]
    # cp[p]: which two global flow channels orientation p contributes to;
    # cpy[p]: which two of the network's output channels provide them
    cp = [(1, 2), (0, 2), (0, 1)]
    cpy = [(0, 1), (0, 1), (0, 1)]
    shape = imgs.shape[:-1]
    yf = np.zeros((*shape, 4), "float32")
    for p in range(3):
        xsl = imgs.transpose(pm[p])
        # per image
        core_logger.info("running %s: %d planes of size (%d, %d)" %
                         (sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]]))
        y, style = run_net(net,
                           xsl, batch_size=batch_size, augment=augment,
                           bsize=bsize, tile_overlap=tile_overlap,
                           rsz=None)
        # accumulate cellprob (last channel) and the two flow components
        yf[..., -1] += y[..., -1].transpose(ipm[p])
        for j in range(2):
            yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])
        y = None; del y  # free the per-orientation output before the next pass

        if progress is not None:
            progress.setValue(25 + 15 * p)

    return yf, style
models/seg_post_model/cellpose/denoise.py ADDED
@@ -0,0 +1,1474 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import os, time, datetime
5
+ import numpy as np
6
+ from scipy.stats import mode
7
+ import cv2
8
+ import torch
9
+ from torch import nn
10
+ from torch.nn.functional import conv2d, interpolate
11
+ from tqdm import trange
12
+ from pathlib import Path
13
+
14
+ import logging
15
+
16
+ denoise_logger = logging.getLogger(__name__)
17
+
18
+ from cellpose import transforms, utils, io
19
+ from cellpose.core import run_net
20
+ from cellpose.models import CellposeModel, model_path, normalize_default, assign_device
21
+
22
# Catalog of restoration model names: each restoration type x cell type, plus
# per/seg/rec loss variants and an anisotropy model for the pre-cyto3 nets.
MODEL_NAMES = []
for _cell in ["cyto3", "cyto2", "nuclei"]:
    for _restore in ["denoise", "deblur", "upsample", "oneclick"]:
        MODEL_NAMES.append(f"{_restore}_{_cell}")
        if _cell != "cyto3":
            MODEL_NAMES.extend(f"{_restore}_{_loss}_{_cell}"
                               for _loss in ["per", "seg", "rec"])
    if _cell != "cyto3":
        MODEL_NAMES.append(f"aniso_{_cell}")

# shared loss criteria: MSE for reconstruction, BCE-with-logits for cellprob
criterion = nn.MSELoss(reduction="mean")
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
34
+
35
+
36
def deterministic(seed=0):
    """Seed all RNGs (numpy, python random, torch CPU/CUDA) to create test data
    reproducibly, and force deterministic cuDNN behavior."""
    import random
    np.random.seed(seed)      # Numpy module.
    random.seed(seed)         # Python random module.
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
47
+
48
+
49
def loss_fn_rec(lbl, y):
    """Reconstruction loss between true labels lbl and prediction y:
    mean-squared error scaled by a fixed factor of 80."""
    return 80. * criterion(y, lbl)
53
+
54
+
55
def loss_fn_seg(lbl, y):
    """Segmentation loss between true labels lbl and prediction y.

    Combines halved MSE on the (5x-scaled) flow-field channels with a
    BCE-with-logits loss on the cell-probability channel.
    """
    flow_target = 5. * lbl[:, 1:]
    cellprob_target = (lbl[:, 0] > .5).float()
    flow_loss = criterion(y[:, :2], flow_target) / 2.
    prob_loss = criterion2(y[:, 2], cellprob_target)
    return flow_loss + prob_loss
64
+
65
+
66
def get_sigma(Tdown):
    """ Calculates the correlation matrices across channels for the perceptual loss.

    Args:
        Tdown (list): List of tensors output by each downsampling block of network.

    Returns:
        list: List of channel-correlation matrices (one per input tensor),
        each of shape [batch x nchan x nchan].
    """
    Sigma = []
    for t in Tdown:
        # standardize each channel map over its spatial dimensions
        z = t - t.mean((-2, -1), keepdim=True)
        z = z / z.std((-2, -1), keepdim=True)
        # channel-by-channel inner product, averaged over spatial positions
        Sigma.append(torch.einsum("bnxy, bmxy -> bnm", z, z) /
                     (z.shape[-2] * z.shape[-1]))
    return Sigma
82
+
83
+
84
def imstats(X, net1):
    """
    Calculates the image correlation matrices for the perceptual loss.

    Args:
        X (torch.Tensor): Input image tensor.
        net1: Cellpose net.

    Returns:
        list: A list of tensors of correlation matrices, detached so they act
        as fixed targets (no gradients flow into them).
    """
    _, _, Tdown = net1(X)
    return [sigma.detach() for sigma in get_sigma(Tdown)]
99
+
100
+
101
def loss_fn_per(img, net1, yl):
    """
    Calculates the perceptual loss function for image restoration.

    Args:
        img (torch.Tensor): Input image tensor (noisy/blurry/downsampled).
        net1 (torch.nn.Module): Perceptual loss net (Cellpose segmentation net).
        yl: Downsampling-block outputs for the restored image (passed to get_sigma).

    Returns:
        torch.Tensor: Mean perceptual loss.
    """
    Sigma_ref = imstats(img, net1)
    Sigma_test = get_sigma(yl)
    # accumulate, per batch element, each layer's squared correlation error
    # normalized by that layer's variability
    losses = torch.zeros(len(Sigma_ref[0]), device=img.device)
    for ref, test in zip(Sigma_ref, Sigma_test):
        sd = ref.std((1, 2)) + 1e-6
        losses = losses + ((test - ref) ** 2).mean((1, 2)) / sd ** 2
    return losses.mean()
120
+
121
+
122
+ def test_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
123
+ """
124
+ Calculates the test loss for image restoration tasks.
125
+
126
+ Args:
127
+ net0 (torch.nn.Module): The image restoration network.
128
+ X (torch.Tensor): The input image tensor.
129
+ net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
130
+ img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
131
+ lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
132
+ lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].
133
+
134
+ Returns:
135
+ tuple: A tuple containing the total loss and the perceptual loss.
136
+ """
137
+ net0.eval()
138
+ if net1 is not None:
139
+ net1.eval()
140
+ loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)
141
+
142
+ with torch.no_grad():
143
+ img_dn = net0(X)[0]
144
+ if lam[2] > 0.:
145
+ loss += lam[2] * loss_fn_rec(img, img_dn)
146
+ if lam[1] > 0. or lam[0] > 0.:
147
+ y, _, ydown = net1(img_dn)
148
+ if lam[1] > 0.:
149
+ loss += lam[1] * loss_fn_seg(lbl, y)
150
+ if lam[0] > 0.:
151
+ loss_per = loss_fn_per(img, net1, ydown)
152
+ loss += lam[0] * loss_per
153
+ return loss, loss_per
154
+
155
+
156
def train_loss(net0, X, net1=None, img=None, lbl=None, lam=(1., 1.5, 0.)):
    """
    Calculates the train loss for image restoration tasks (gradients enabled).

    Args:
        net0 (torch.nn.Module): The image restoration network.
        X (torch.Tensor): The input image tensor.
        net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
        img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
        lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
        lam (sequence, optional): The weights for different loss components
            (perceptual, segmentation, reconstruction). Defaults to (1., 1.5, 0.).
            (Default changed from a mutable list to a tuple to avoid the
            shared-mutable-default pitfall; callers passing lists still work.)

    Returns:
        tuple: A tuple containing the total loss and the perceptual loss.
    """
    net0.train()
    if net1 is not None:
        # segmentation net stays in eval mode: it only scores the output
        net1.eval()
    loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)

    img_dn = net0(X)[0]  # restored image
    if lam[2] > 0.:
        loss += lam[2] * loss_fn_rec(img, img_dn)
    if lam[1] > 0. or lam[0] > 0.:
        y, _, ydown = net1(img_dn)
        if lam[1] > 0.:
            loss += lam[1] * loss_fn_seg(lbl, y)
        if lam[0] > 0.:
            loss_per = loss_fn_per(img, net1, ydown)
            loss += lam[0] * loss_per
    return loss, loss_per
187
+
188
+
189
def img_norm(imgi):
    """Percentile-normalize an image batch (in place).

    Each channel is shifted by its 1st percentile and scaled by the
    (99th - 1st) percentile range; channels with a near-zero range
    (< 1e-3) are left untouched.

    Args:
        imgi (torch.Tensor): input image tensor of shape (nimg, nchan, Ly, Lx).

    Returns:
        torch.Tensor: normalized image tensor (same storage as input).
    """
    orig_shape = imgi.shape
    flat = imgi.reshape(orig_shape[0], orig_shape[1], -1)
    qs = torch.quantile(flat, torch.tensor([0.01, 0.99], device=imgi.device),
                        dim=-1, keepdim=True)
    low, high = qs[0], qs[1]
    for c in range(flat.shape[1]):
        # only normalize images whose channel range is non-degenerate
        valid = (high[:, c, 0] - low[:, c, 0]) > 1e-3
        flat[valid, c] -= low[valid, c]
        flat[valid, c] /= high[valid, c] - low[valid, c]
    return flat.reshape(orig_shape)
209
+
210
+
211
def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7,
              ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None,
              ds=None, uniform_blur=False, partial_blur=False):
    """Adds noise to the input image.

    Args:
        lbl (torch.Tensor): The input image tensor of shape (nimg, nchan, Ly, Lx).
        alpha (float, optional): The shape parameter of the gamma distribution used for generating poisson noise. Defaults to 4.
        beta (float, optional): The rate parameter of the gamma distribution used for generating poisson noise. Defaults to 0.7.
        poisson (float, optional): The probability of adding poisson noise to the image. Defaults to 0.7.
        blur (float, optional): The probability of adding gaussian blur to the image. Defaults to 0.7.
        gblur (float, optional): The scale factor for the gaussian blur. Defaults to 1.0.
        downsample (float, optional): The probability of downsampling the image. Defaults to 0.7.
        ds_max (int, optional): The maximum downsampling factor. Defaults to 7.
        diams (torch.Tensor, optional): The diameter of the objects in the image. Defaults to None.
        pscale (torch.Tensor, optional): The scale factor for the poisson noise, instead of sampling. Defaults to None.
        iso (bool, optional): Whether to use isotropic gaussian blur. Defaults to True.
        sigma0 (torch.Tensor, optional): The standard deviation of the gaussian filter for the Y axis, instead of sampling. Defaults to None.
        sigma1 (torch.Tensor, optional): The standard deviation of the gaussian filter for the X axis, instead of sampling. Defaults to None.
        ds (torch.Tensor, optional): The downsampling factor for each image, instead of sampling. Defaults to None.
        uniform_blur (bool, optional): Sample the blur level uniformly (rather than exponentially). Defaults to False.
        partial_blur (bool, optional): Blend the blurred image into the clean one over part of the X extent. Defaults to False.

    Returns:
        torch.Tensor: The noisy image tensor of the same shape as the input image.
    """
    device = lbl.device
    imgi = torch.zeros_like(lbl)
    Ly, Lx = lbl.shape[-2:]

    # default object diameter of 30 px per image
    diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device)
    #ds0 = 1 if ds is None else ds.item()
    # broadcast a scalar downsampling factor to one per image
    ds = ds * torch.ones(
        (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds

    # downsample
    # ii holds indices of images that will actually be downsampled
    ii = []
    idownsample = np.random.rand(len(lbl)) < downsample
    if (ds is None and idownsample.sum() > 0.) or not iso:
        ds = torch.ones(len(lbl), dtype=torch.long, device=device)
        ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),),
                                        device=device)
        ii = torch.nonzero(ds > 1).flatten()
    elif ds is not None and (ds > 1).sum():
        ii = torch.nonzero(ds > 1).flatten()

    # add gaussian blur
    iblur = torch.rand(len(lbl), device=device) < blur
    # downsampled images are always blurred first (anti-aliasing)
    iblur[ii] = True
    if iblur.sum() > 0:
        if sigma0 is None:
            if uniform_blur and iso:
                xr = torch.rand(len(lbl), device=device)
                if len(ii) > 0:
                    # tie blur level to the downsampling factor
                    xr[ii] = ds[ii].float() / 2. / gblur
                sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur])
                sigma1 = sigma0.clone()
            elif not iso:
                xr = torch.rand(len(lbl), device=device)
                if len(ii) > 0:
                    xr[ii] = (ds[ii].float()) / gblur
                    # jitter the anisotropic blur level
                    xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35
                    xr[ii] = torch.clip(xr[ii], 0.05, 1.5)
                sigma0 = diams[iblur] / 30. * gblur * xr[iblur]
                # anisotropic: X-axis blur is 10x weaker than Y-axis
                sigma1 = sigma0.clone() / 10.
            else:
                # exponential sampling of blur strength, clipped to [0.1, 1.0]
                xrand = np.random.exponential(1, size=iblur.sum())
                xrand = np.clip(xrand * 0.5, 0.1, 1.0)
                xrand *= gblur
                sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to(
                    device)
                sigma1 = sigma0.clone()
        else:
            # user-specified sigmas, broadcast to every blurred image
            sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device)
            sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device)

        # create gaussian filter
        # xr is the filter half-width (at least 8 taps each side)
        xr = max(8, sigma0.max().long() * 2)
        gfilt0 = torch.exp(-torch.arange(-xr + 1, xr, device=device)**2 /
                           (2 * sigma0.unsqueeze(-1)**2))
        gfilt0 /= gfilt0.sum(axis=-1, keepdims=True)
        gfilt1 = torch.zeros_like(gfilt0)
        # reuse the Y-axis filter where sigmas are equal
        gfilt1[sigma1 == sigma0] = gfilt0[sigma1 == sigma0]
        gfilt1[sigma1 != sigma0] = torch.exp(
            -torch.arange(-xr + 1, xr, device=device)**2 /
            (2 * sigma1[sigma1 != sigma0].unsqueeze(-1)**2))
        # sigma1 == 0 means no X-axis blur: delta filter centered at xr
        gfilt1[sigma1 == 0] = 0.
        gfilt1[sigma1 == 0, xr] = 1.
        gfilt1 /= gfilt1.sum(axis=-1, keepdims=True)
        # outer product -> separable 2D gaussian per image
        gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1)
        gfilt /= gfilt.sum(axis=(1, 2), keepdims=True)

        # grouped conv applies each image's own filter to all its channels
        lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1),
                          padding=gfilt.shape[-1] // 2,
                          groups=gfilt.shape[0]).transpose(1, 0)
        if partial_blur:
            #yc, xc = np.random.randint(100, Ly-100), np.random.randint(100, Lx-100)
            imgi[iblur] = lbl[iblur].clone()
            Lxc = int(Lx * 0.85)
            ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32),
                                    torch.arange(0, Lxc, dtype=torch.float32),
                                    indexing="ij")
            # NOTE(review): `/ 2*(0.001**2)` parses as `(.../2) * 1e-6`, not
            # `/(2 * 0.001**2)` -- after the min/max rescale below the mask is
            # still a left-to-right gradient, but confirm the intended falloff
            mask = torch.exp(-(ym**2 + xm**2) / 2*(0.001**2))
            mask -= mask.min()
            mask /= mask.max()
            lbl_blur_crop = lbl_blur[:, :, :, :Lxc]
            # blend blurred and clean image across the left 85% of X
            imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask +
                                       (1-mask) * imgi[iblur, :, :, :Lxc])
        else:
            imgi[iblur] = lbl_blur

    imgi[~iblur] = lbl[~iblur]

    # apply downsample
    for k in ii:
        # strided slicing then bilinear upsampling back to the original size
        i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]]
        imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear")

    # add poisson noise
    ipoisson = np.random.rand(len(lbl)) < poisson
    if ipoisson.sum() > 0:
        if pscale is None:
            pscale = torch.zeros(len(lbl))
            # sample photon-count scale from a gamma distribution, >= 1
            m = torch.distributions.gamma.Gamma(alpha, beta)
            pscale = torch.clamp(m.rsample(sample_shape=(ipoisson.sum(),)), 1.)
            #pscale = torch.clamp(20 * (torch.rand(size=(len(lbl),), device=lbl.device)), 1.5)
            pscale = pscale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device)
        else:
            pscale = pscale * torch.ones((ipoisson.sum(), 1, 1, 1), device=device)
        imgi[ipoisson] = torch.poisson(pscale * imgi[ipoisson])
        # no-op: kept for symmetry with the blur branch above
        imgi[~ipoisson] = imgi[~ipoisson]

    # renormalize
    imgi = img_norm(imgi)

    return imgi
345
+
346
+
347
def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7,
                                   downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30,
                                   ds_max=7, uniform_blur=False, iso=True, rotate=True,
                                   device=torch.device("cuda"), xy=(224, 224),
                                   nchan_noise=1, keep_raw=True):
    """
    Applies random rotation, resizing, and noise to the input data.

    Args:
        data (numpy.ndarray): The input data.
        labels (numpy.ndarray, optional): The flow and cellprob labels associated with the data. Defaults to None.
        diams (float, optional): The diameter of the objects. Defaults to None.
        poisson (float, optional): The Poisson noise probability. Defaults to 0.7.
        blur (float, optional): The blur probability. Defaults to 0.7.
        downsample (float, optional): The downsample probability. Defaults to 0.0.
        beta (float, optional): The beta value for the poisson noise distribution. Defaults to 0.7.
        gblur (float, optional): The Gaussian blur level. Defaults to 1.0.
        diam_mean (float, optional): The mean diameter. Defaults to 30.
        ds_max (int, optional): The maximum downsample value. Defaults to 7.
        uniform_blur (bool, optional): Sample blur level uniformly. Defaults to False.
        iso (bool, optional): Whether to apply isotropic augmentation. Defaults to True.
        rotate (bool, optional): Whether to apply rotation augmentation. Defaults to True.
        device (torch.device, optional): The device to use. Defaults to torch.device("cuda").
        xy (tuple, optional): The size of the output image. Defaults to (224, 224).
        nchan_noise (int, optional): The number of channels to add noise to. Defaults to 1.
        keep_raw (bool, optional): Whether to keep the raw image. Defaults to True.

    Returns:
        torch.Tensor: The augmented image and augmented noisy/blurry/downsampled version of image.
        torch.Tensor: The augmented labels.
        float: The scale factor applied to the image.
    """
    # NOTE(review): `== None` should be `is None` per PEP 8; behavior is the
    # same for torch.device but the identity test is the safe idiom
    if device == None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

    diams = 30 if diams is None else diams
    # random per-image target diameter in [diam_mean/2, diam_mean*2]
    random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1))
    random_rsc = diams / random_diam  #/ random_diam
    #rsc /= random_scale
    # intermediate canvas size before the final crop/rotate to `xy`
    xy0 = (340, 340)
    nchan = data[0].shape[0]
    # when keep_raw, channels are [noisy..., raw...] stacked along axis 1
    data_new = np.zeros((len(data), (1 + keep_raw) * nchan, xy0[0], xy0[1]), "float32")
    labels_new = np.zeros((len(data), 3, xy0[0], xy0[1]), "float32")
    for i in range(
            len(data)):  #, (sc, img, lbl) in enumerate(zip(random_rsc, data, labels)):
        sc = random_rsc[i]
        img = data[i]
        lbl = labels[i] if labels is not None else None
        # create affine transform to resize
        Ly, Lx = img.shape[-2:]
        # random translation within the slack left after rescaling
        dxy = np.maximum(0, np.array([Lx / sc - xy0[1], Ly / sc - xy0[0]]))
        dxy = (np.random.rand(2,) - .5) * dxy
        cc = np.array([Lx / 2, Ly / 2])
        cc1 = cc - np.array([Lx - xy0[1], Ly - xy0[0]]) / 2 + dxy
        # three point correspondences define the scale+translate affine map
        pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
        pts2 = np.float32(
            [cc1, cc1 + np.array([1, 0]) / sc, cc1 + np.array([0, 1]) / sc])
        M = cv2.getAffineTransform(pts1, pts2)

        # apply to image
        for c in range(nchan):
            img_rsz = cv2.warpAffine(img[c], M, xy0, flags=cv2.INTER_LINEAR)
            #img_noise = add_noise(torch.from_numpy(img_rsz).to(device).unsqueeze(0)).cpu().numpy().squeeze(0)
            data_new[i, c] = img_rsz
            if keep_raw:
                # duplicate: the second copy stays clean (noise added below
                # only touches the first nchan_noise channels)
                data_new[i, c + nchan] = img_rsz

        if lbl is not None:
            # apply to labels
            # channel 0 is the cellprob/mask (nearest), 1-2 are flows (linear)
            labels_new[i, 0] = cv2.warpAffine(lbl[0], M, xy0, flags=cv2.INTER_NEAREST)
            labels_new[i, 1] = cv2.warpAffine(lbl[1], M, xy0, flags=cv2.INTER_LINEAR)
            labels_new[i, 2] = cv2.warpAffine(lbl[2], M, xy0, flags=cv2.INTER_LINEAR)

    rsc = random_diam / diam_mean

    # add noise before augmentations
    img = torch.from_numpy(data_new).to(device)
    img = torch.clamp(img, 0.)
    # just add noise to cyto if nchan_noise=1
    img[:, :nchan_noise] = add_noise(
        img[:, :nchan_noise], poisson=poisson, blur=blur, ds_max=ds_max, iso=iso,
        downsample=downsample, beta=beta, gblur=gblur,
        diams=torch.from_numpy(random_diam).to(device).float())
    # img -= img.mean(dim=(-2,-1), keepdim=True)
    # img /= img.std(dim=(-2,-1), keepdim=True) + 1e-3
    img = img.cpu().numpy()

    # augmentations
    img, lbl, scale = transforms.random_rotate_and_resize(
        img,
        Y=labels_new,
        xy=xy,
        rotate=False if not iso else rotate,
        #(iso and downsample==0),
        rescale=rsc,
        scale_range=0.5)
    img = torch.from_numpy(img).to(device)
    lbl = torch.from_numpy(lbl).to(device)

    return img, lbl, scale
446
+
447
+
448
def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
    """
    Creates a Cellpose network with a single input channel.

    The pretrained weights (trained with 2 input channels) are copied over;
    for the first downsampling layer ("res_down_0") the 2-channel weights are
    collapsed onto channel 0. The diam_mean/diam_labels buffers keep their
    freshly-initialized values.

    Args:
        device (str): The device to run the network on.
        model_type (str, optional): The type of Cellpose model to use. Defaults to "cyto2".
        pretrained_model (str, optional): The path to a pretrained model file. Defaults to None.

    Returns:
        torch.nn.Module: The Cellpose network with a single input channel.
    """
    if pretrained_model is not None and not os.path.exists(pretrained_model):
        # a built-in model name was passed in the pretrained_model slot
        model_type, pretrained_model = pretrained_model, None

    net1 = resnet_torch.CPnet([1, 32, 64, 128, 256], nout=3, sz=3).to(device)
    filename = pretrained_model if pretrained_model is not None else model_path(
        model_type, 0)
    weights = torch.load(filename, weights_only=True)
    print(filename)

    state = net1.state_dict()
    for name in state:
        first_layer = ("res_down_0.conv.conv_0" in name or
                       "res_down_0.proj" in name)
        if (not first_layer and name != "diam_mean" and name != "diam_labels"):
            # shapes match everywhere past the first layer: copy verbatim
            state[name].copy_(weights[name])
        elif "res_down_0" in name:
            w = weights[name]
            if len(w.shape) > 0:
                remapped = torch.zeros_like(state[name])
                if w.shape[0] == 2:
                    # e.g. per-channel bias: keep channel 0 only
                    remapped[:] = w[0]
                elif len(w.shape) > 1 and w.shape[1] == 2:
                    # input-channel dim: place channel 0 weights at index 0
                    remapped[:, 0] = w[:, 0]
                else:
                    remapped = w
            else:
                remapped = w
            state[name].copy_(remapped)
    return net1
490
+
491
+
492
class CellposeDenoiseModel():
    """Model to run Cellpose image restoration followed by segmentation.

    Wraps a :class:`DenoiseModel` (restoration) and a :class:`CellposeModel`
    (segmentation); ``eval`` runs them in sequence.
    """

    def __init__(self, gpu=False, pretrained_model=False, model_type=None,
                 restore_type="denoise_cyto3", nchan=2,
                 chan2_restore=False, device=None):

        # restoration network (denoise / deblur / upsample)
        self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore,
                               device=device)
        # segmentation network, run on the restored images
        self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan,
                                pretrained_model=pretrained_model, device=device)

    def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
             normalize=True, rescale=None, diameter=None, tile_overlap=0.1,
             augment=False, resample=True, invert=False, flow_threshold=0.4,
             cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
             min_size=15, niter=None, interp=True, bsize=224, flow3D_smooth=0):
        """
        Restore array or list of images using the image restoration model, and then segment.

        Args:
            x (list, np.ndarray): list or array of 2D/3D/4D images.
            batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU. Defaults to 8.
            channels (list, optional): [chan_to_segment, optional_nuclear_chan] with
                0=grayscale, 1=red, 2=green, 3=blue; may be per-image (list of pairs). Defaults to None.
            channel_axis (int, optional): channel axis in x; auto-detected if None. Defaults to None.
            z_axis (int, optional): z axis in x; auto-detected if None. Defaults to None.
            normalize (bool or dict, optional): if True, percentile-normalize each channel
                (1st/99th); a dict may override keys of ``normalize_default``
                ("lowhigh", "sharpen", "normalize", "percentile", "tile_norm", "norm3D"). Defaults to True.
            rescale (float, optional): resize factor per image; only used if diameter is None. Defaults to None.
            diameter (float, optional): diameter per image; if None, uses diam_mean/diam_train. Defaults to None.
            tile_overlap (float, optional): fraction of tile overlap when computing flows. Defaults to 0.1.
            augment (bool, optional): flip-and-average tiles during segmentation. Defaults to False.
            resample (bool, optional): run dynamics at original image size. Defaults to True.
            invert (bool, optional): invert pixel intensities before the network. Defaults to False.
            flow_threshold (float, optional): flow error threshold for keeping cells (2D only). Defaults to 0.4.
            cellprob_threshold (float, optional): pixel threshold for mask seeds. Defaults to 0.0.
            do_3D (bool, optional): run 3D segmentation on 3D/4D input. Defaults to False.
            anisotropy (float, optional): 3D z-rescaling factor. Defaults to None.
            stitch_threshold (float, optional): if >0 and not do_3D, stitch 2D masks into 3D. Defaults to 0.0.
            min_size (int, optional): discard ROIs below this pixel count. Defaults to 15.
            niter (int, optional): dynamics iterations; if None, proportional to diameter. Defaults to None.
            interp (bool, optional): interpolate during 2D dynamics. Defaults to True.
            bsize (int, optional): tile size in pixels. Defaults to 224.
            flow3D_smooth (int, optional): gaussian stddev for smoothing 3D flows. Defaults to 0.

        Returns:
            tuple: (masks, flows, styles, imgs) — labelled masks (0=background),
            flow visualizations/values/cellprob/pixel trajectories, per-image
            style vectors (size 256), and the restored images.
        """

        if isinstance(normalize, dict):
            normalize_params = {**normalize_default, **normalize}
        elif not isinstance(normalize, bool):
            raise ValueError("normalize parameter must be a bool or a dict")
        else:
            # NOTE(review): this aliases the module-level `normalize_default`
            # dict, so the two assignments below mutate the shared default —
            # confirm this cross-call leakage is intended
            normalize_params = normalize_default
            normalize_params["normalize"] = normalize
            normalize_params["invert"] = invert

        img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels,
                                   channel_axis=channel_axis, z_axis=z_axis,
                                   do_3D=do_3D,
                                   normalize=normalize_params, rescale=rescale,
                                   diameter=diameter,
                                   tile_overlap=tile_overlap, bsize=bsize)

        # turn off special normalization for segmentation
        normalize_params = normalize_default

        # change channels for segmentation
        if channels is not None:
            channels_new = [0, 0] if channels[0] == 0 else [1, 2]
        else:
            channels_new = None
        # change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean)
        diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
        masks, flows, styles = self.cp.eval(
            img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
            z_axis=0 if not isinstance(img_restore, list) and img_restore.ndim > 3 and img_restore.shape[0] > 0 else None,
            normalize=normalize_params, rescale=rescale, diameter=diameter,
            tile_overlap=tile_overlap, augment=augment, resample=resample,
            invert=invert, flow_threshold=flow_threshold,
            cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy,
            stitch_threshold=stitch_threshold, min_size=min_size, niter=niter,
            interp=interp, bsize=bsize)

        return masks, flows, styles, img_restore
598
+
599
+
600
+ class DenoiseModel():
601
+ """
602
+ DenoiseModel class for denoising images using Cellpose denoising model.
603
+
604
+ Args:
605
+ gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
606
+ pretrained_model (bool or str or Path, optional): Pretrained model to use for denoising.
607
+ Can be a string or path. Defaults to False.
608
+ nchan (int, optional): Number of channels in the input images, all Cellpose 3 models were trained with nchan=1. Defaults to 1.
609
+ model_type (str, optional): Type of pretrained model to use ("denoise_cyto3", "deblur_cyto3", "upsample_cyto3", ...). Defaults to None.
610
+ chan2 (bool, optional): Whether to use a separate model for the second channel. Defaults to False.
611
+ diam_mean (float, optional): Mean diameter of the objects in the images. Defaults to 30.0.
612
+ device (torch.device, optional): Device to use for computation. Defaults to None.
613
+
614
+ Attributes:
615
+ nchan (int): Number of channels in the input images.
616
+ diam_mean (float): Mean diameter of the objects in the images.
617
+ net (CPnet): Cellpose network for denoising.
618
+ pretrained_model (bool or str or Path): Pretrained model path to use for denoising.
619
+ net_chan2 (CPnet or None): Cellpose network for the second channel, if applicable.
620
+ net_type (str): Type of the denoising network.
621
+
622
+ Methods:
623
+ eval(x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
624
+ normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1)
625
+ Denoise array or list of images using the denoising model.
626
+
627
+ _eval(net, x, normalize=True, rescale=None, diameter=None, tile=True,
628
+ tile_overlap=0.1)
629
+ Run denoising model on a single channel.
630
+ """
631
+
632
+ def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
633
+ chan2=False, diam_mean=30., device=None):
634
+ self.nchan = nchan
635
+ if pretrained_model and (not isinstance(pretrained_model, str) and
636
+ not isinstance(pretrained_model, Path)):
637
+ raise ValueError("pretrained_model must be a string or path")
638
+
639
+ self.diam_mean = diam_mean
640
+ builtin = True
641
+ if model_type is not None or (pretrained_model and
642
+ not os.path.exists(pretrained_model)):
643
+ pretrained_model_string = model_type if model_type is not None else "denoise_cyto3"
644
+ if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]):
645
+ pretrained_model_string = "denoise_cyto3"
646
+ pretrained_model = model_path(pretrained_model_string)
647
+ if (pretrained_model and not os.path.exists(pretrained_model)):
648
+ denoise_logger.warning("pretrained model has incorrect path")
649
+ denoise_logger.info(f">> {pretrained_model_string} << model set to be used")
650
+ self.diam_mean = 17. if "nuclei" in pretrained_model_string else 30.
651
+ else:
652
+ if pretrained_model:
653
+ builtin = False
654
+ pretrained_model_string = pretrained_model
655
+ denoise_logger.info(f">>>> loading model {pretrained_model_string}")
656
+
657
+ # assign network device
658
+ if device is None:
659
+ sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
660
+ self.device = device if device is not None else sdevice
661
+ if device is not None:
662
+ device_gpu = self.device.type == "cuda"
663
+ self.gpu = gpu if device is None else device_gpu
664
+
665
+ # create network
666
+ self.nchan = nchan
667
+ self.nclasses = 1
668
+ nbase = [32, 64, 128, 256]
669
+ self.nchan = nchan
670
+ self.nbase = [nchan, *nbase]
671
+
672
+ self.net = CPnet(self.nbase, self.nclasses, sz=3,
673
+ max_pool=True, diam_mean=diam_mean).to(self.device)
674
+
675
+ self.pretrained_model = pretrained_model
676
+ self.net_chan2 = None
677
+ if self.pretrained_model:
678
+ self.net.load_model(self.pretrained_model, device=self.device)
679
+ denoise_logger.info(
680
+ f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)"
681
+ )
682
+ if chan2 and builtin:
683
+ chan2_path = model_path(
684
+ os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei")
685
+ print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
686
+ self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
687
+ max_pool=True,
688
+ diam_mean=17.).to(self.device)
689
+ self.net_chan2.load_model(chan2_path, device=self.device)
690
+ self.net_type = "cellpose_denoise"
691
+
692
+ def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
693
+ normalize=True, rescale=None, diameter=None, tile=True, do_3D=False,
694
+ tile_overlap=0.1, bsize=224):
695
+ """
696
+ Restore array or list of images using the image restoration model.
697
+
698
+ Args:
699
+ x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
700
+ batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
701
+ (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
702
+ channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
703
+ First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
704
+ Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
705
+ For instance, to segment grayscale images, input [0,0]. To segment images with cells
706
+ in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
707
+ image with cells in green and nuclei in blue, input [[0,0], [2,3]].
708
+ Defaults to None.
709
+ channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
710
+ if None, channels dimension is attempted to be automatically determined. Defaults to None.
711
+ z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
712
+ if None, z dimension is attempted to be automatically determined. Defaults to None.
713
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
714
+ can also pass dictionary of parameters (all keys are optional, default values shown):
715
+ - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
716
+ - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
717
+ - "normalize"=True ; run normalization (if False, all following parameters ignored)
718
+ - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
719
+ - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
720
+ - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
721
+ Defaults to True.
722
+ rescale (float, optional): resize factor for each image, if None, set to 1.0;
723
+ (only used if diameter is None). Defaults to None.
724
+ diameter (float, optional): diameter for each image,
725
+ if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
726
+ tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
727
+
728
+ Returns:
729
+ list: A list of 2D/3D arrays of restored images
730
+
731
+ """
732
+ if isinstance(x, list) or x.squeeze().ndim == 5:
733
+ tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO)
734
+ nimg = len(x)
735
+ iterator = trange(nimg, file=tqdm_out,
736
+ mininterval=30) if nimg > 1 else range(nimg)
737
+ imgs = []
738
+ for i in iterator:
739
+ imgi = self.eval(
740
+ x[i], batch_size=batch_size,
741
+ channels=channels[i] if channels is not None and
742
+ ((len(channels) == len(x) and
743
+ (isinstance(channels[i], list) or
744
+ isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2))
745
+ else channels, channel_axis=channel_axis, z_axis=z_axis,
746
+ normalize=normalize,
747
+ do_3D=do_3D,
748
+ rescale=rescale[i] if isinstance(rescale, list) or
749
+ isinstance(rescale, np.ndarray) else rescale,
750
+ diameter=diameter[i] if isinstance(diameter, list) or
751
+ isinstance(diameter, np.ndarray) else diameter,
752
+ tile_overlap=tile_overlap, bsize=bsize)
753
+ imgs.append(imgi)
754
+ if isinstance(x, np.ndarray):
755
+ imgs = np.array(imgs)
756
+ return imgs
757
+
758
+ else:
759
+ # reshape image
760
+ x = transforms.convert_image(x, channels, channel_axis=channel_axis,
761
+ z_axis=z_axis, do_3D=do_3D, nchan=None)
762
+ if x.ndim < 4:
763
+ squeeze = True
764
+ x = x[np.newaxis, ...]
765
+ else:
766
+ squeeze = False
767
+
768
+ # may need to interpolate image before running upsampling
769
+ self.ratio = 1.
770
+ if "upsample" in self.pretrained_model:
771
+ Ly, Lx = x.shape[-3:-1]
772
+ if diameter is not None and 3 <= diameter < self.diam_mean:
773
+ self.ratio = self.diam_mean / diameter
774
+ denoise_logger.info(
775
+ f"upsampling image to {self.diam_mean} pixel diameter ({self.ratio:0.2f} times)"
776
+ )
777
+ Lyr, Lxr = int(Ly * self.ratio), int(Lx * self.ratio)
778
+ x = transforms.resize_image(x, Ly=Lyr, Lx=Lxr)
779
+ else:
780
+ denoise_logger.warning(
781
+ f"not interpolating image before upsampling because diameter is set >= {self.diam_mean}"
782
+ )
783
+ #raise ValueError(f"diameter is set to {diameter}, needs to be >=3 and < {self.dn.diam_mean}")
784
+
785
+ self.batch_size = batch_size
786
+
787
+ if diameter is not None and diameter > 0:
788
+ rescale = self.diam_mean / diameter
789
+ elif rescale is None:
790
+ rescale = 1.0
791
+
792
+ if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0):
793
+ x = x[..., :1]
794
+
795
+ for c in range(x.shape[-1]):
796
+ rescale0 = rescale * 30. / 17. if c == 1 else rescale
797
+ if c == 0 or self.net_chan2 is None:
798
+ x[...,
799
+ c] = self._eval(self.net, x[..., c:c + 1], batch_size=batch_size,
800
+ normalize=normalize, rescale=rescale0,
801
+ tile_overlap=tile_overlap, bsize=bsize)[...,0]
802
+ else:
803
+ x[...,
804
+ c] = self._eval(self.net_chan2, x[...,
805
+ c:c + 1], batch_size=batch_size,
806
+ normalize=normalize, rescale=rescale0,
807
+ tile_overlap=tile_overlap, bsize=bsize)[...,0]
808
+ x = x[0] if squeeze else x
809
+ return x
810
+
811
+ def _eval(self, net, x, batch_size=8, normalize=True, rescale=None,
812
+ tile_overlap=0.1, bsize=224):
813
+ """
814
+ Run image restoration model on a single channel.
815
+
816
+ Args:
817
+ x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
818
+ batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
819
+ (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
820
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
821
+ can also pass dictionary of parameters (all keys are optional, default values shown):
822
+ - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
823
+ - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
824
+ - "normalize"=True ; run normalization (if False, all following parameters ignored)
825
+ - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
826
+ - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
827
+ - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
828
+ Defaults to True.
829
+ rescale (float, optional): resize factor for each image, if None, set to 1.0;
830
+ (only used if diameter is None). Defaults to None.
831
+ tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
832
+
833
+ Returns:
834
+ list: A list of 2D/3D arrays of restored images
835
+
836
+ """
837
+ if isinstance(normalize, dict):
838
+ normalize_params = {**normalize_default, **normalize}
839
+ elif not isinstance(normalize, bool):
840
+ raise ValueError("normalize parameter must be a bool or a dict")
841
+ else:
842
+ normalize_params = normalize_default
843
+ normalize_params["normalize"] = normalize
844
+
845
+ tic = time.time()
846
+ shape = x.shape
847
+ nimg = shape[0]
848
+
849
+ do_normalization = True if normalize_params["normalize"] else False
850
+
851
+ img = np.asarray(x)
852
+ if do_normalization:
853
+ img = transforms.normalize_img(img, **normalize_params)
854
+ if rescale != 1.0:
855
+ img = transforms.resize_image(img, rsz=rescale)
856
+ yf, style = run_net(self.net, img, bsize=bsize,
857
+ tile_overlap=tile_overlap)
858
+ yf = transforms.resize_image(yf, shape[1], shape[2])
859
+ imgs = yf
860
+ del yf, style
861
+
862
+ # imgs = np.zeros((*x.shape[:-1], 1), np.float32)
863
+ # for i in iterator:
864
+ # img = np.asarray(x[i])
865
+ # if do_normalization:
866
+ # img = transforms.normalize_img(img, **normalize_params)
867
+ # if rescale != 1.0:
868
+ # img = transforms.resize_image(img, rsz=[rescale, rescale])
869
+ # if img.ndim == 2:
870
+ # img = img[:, :, np.newaxis]
871
+ # yf, style = run_net(net, img, batch_size=batch_size, augment=False,
872
+ # tile=tile, tile_overlap=tile_overlap, bsize=bsize)
873
+ # img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2])
874
+
875
+ # if img.ndim == 2:
876
+ # img = img[:, :, np.newaxis]
877
+ # imgs[i] = img
878
+ # del yf, style
879
+ net_time = time.time() - tic
880
+ if nimg > 1:
881
+ denoise_logger.info("imgs denoised in %2.2fs" % (net_time))
882
+
883
+ return imgs
884
+
885
+
886
def train(net, train_data=None, train_labels=None, train_files=None, test_data=None,
          test_labels=None, test_files=None, train_probs=None, test_probs=None,
          lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None,
          save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0,
          iso=True, uniform_blur=False, downsample=0., ds_max=7,
          learning_rate=0.005, n_epochs=500,
          weight_decay=0.00001, batch_size=8, nimg_per_epoch=None,
          nimg_test_per_epoch=None, model_name=None):
    """Train a denoising network with perceptual / segmentation / reconstruction losses.

    Either in-memory data (train_data + train_labels) or file paths
    (train_files, with flow files expected at "<name>_flows.tif") must be
    provided; the same convention applies to the optional test set.

    Args:
        net: denoising network to train (provides .device, .nchan, .diam_mean,
            .parameters() and .save_model()).
        train_data / train_labels (list, optional): pre-loaded normalized
            images and their label/flow arrays.
        train_files / test_files (list, optional): image paths used when the
            in-memory data is None.
        train_probs / test_probs (np.ndarray, optional): per-image sampling
            probabilities; uniform if None.
        lam (list): weights for [perceptual, segmentation, reconstruction]
            losses (read-only here, so the mutable default is safe).
        scale_range (float): accepted for interface compatibility; not used in
            this function body.
        seg_model_type (str): pretrained model used for the segmentation loss.
        save_path / save_every / save_each / model_name: checkpointing options.
        poisson, beta, blur, gblur, downsample (float or list): noise/corruption
            parameters; lists define multiple noise regimes applied per batch.
        iso (bool): isotropic augmentation; uniform_blur (bool): uniform PSF.
        ds_max (int): max downsampling factor for augmentation.
        learning_rate, n_epochs, weight_decay, batch_size: optimization options.
        nimg_per_epoch / nimg_test_per_epoch (int, optional): images sampled per
            epoch; defaults to dataset size.

    Returns:
        tuple: (filename, train_losses, test_losses) -- checkpoint path (or
        None when save_path is None) and per-logging-interval loss histories.
    """
    # net properties
    device = net.device
    nchan = net.nchan
    diam_mean = net.diam_mean.item()

    # promote scalar noise parameters to 1-element arrays so that several
    # noise regimes (lists) and a single regime share the same code path
    args = np.array([poisson, beta, blur, gblur, downsample])
    if args.ndim == 1:
        args = args[:, np.newaxis]
    poisson, beta, blur, gblur, downsample = args
    nnoise = len(poisson)

    d = datetime.datetime.now()
    if save_path is not None:
        if model_name is None:
            # build a descriptive checkpoint name from loss weights + noise setup
            filename = ""
            lstrs = ["per", "seg", "rec"]
            for k, (l, s) in enumerate(zip(lam, lstrs)):
                filename += f"{s}_{l:.2f}_"
            if not iso:
                filename += "aniso_"
            if poisson.sum() > 0:
                filename += "poisson_"
            if blur.sum() > 0:
                filename += "blur_"
            if downsample.sum() > 0:
                filename += "downsample_"
            filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f")
            filename = os.path.join(save_path, filename)
        else:
            filename = os.path.join(save_path, model_name)
        print(filename)
    for i in range(len(poisson)):
        denoise_logger.info(
            f"poisson: {poisson[i]: 0.2f}, beta: {beta[i]: 0.2f}, blur: {blur[i]: 0.2f}, gblur: {gblur[i]: 0.2f}, downsample: {downsample[i]: 0.2f}"
        )
    # frozen single-channel segmentation net used for the seg/perceptual losses
    net1 = one_chan_cellpose(device=device, pretrained_model=seg_model_type)

    # LR schedule: linear warmup over 10 epochs, constant, then halved every
    # 10 epochs over the last 100 (max(0, ...) guards small n_epochs, which
    # previously raised in np.ones with a negative size)
    learning_rate_const = learning_rate
    LR = np.linspace(0, learning_rate_const, 10)
    LR = np.append(LR, learning_rate_const * np.ones(max(0, n_epochs - 100)))
    for i in range(10):
        LR = np.append(LR, LR[-1] / 2 * np.ones(10))
    learning_rate = LR

    # NOTE: batch_size now honors the caller's argument (a hard-coded
    # `batch_size = 8` previously overrode it silently)
    optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate[0],
                                  weight_decay=weight_decay)
    if train_data is not None:
        nimg = len(train_data)
        diam_train = np.array(
            [utils.diameters(train_labels[k])[0] for k in trange(len(train_labels))])
        diam_train[diam_train < 5] = 5.
        if test_data is not None:
            diam_test = np.array(
                [utils.diameters(test_labels[k])[0] for k in trange(len(test_labels))])
            diam_test[diam_test < 5] = 5.
            nimg_test = len(test_data)
    else:
        nimg = len(train_files)
        denoise_logger.info(">>> using files instead of loading dataset")
        train_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in train_files]
        denoise_logger.info(">>> computing diameters")
        diam_train = np.array([
            utils.diameters(io.imread(train_labels_files[k])[0])[0]
            for k in trange(len(train_labels_files))
        ])
        diam_train[diam_train < 5] = 5.
        if test_files is not None:
            nimg_test = len(test_files)
            test_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in test_files]
            diam_test = np.array([
                utils.diameters(io.imread(test_labels_files[k])[0])[0]
                for k in trange(len(test_labels_files))
            ])
            diam_test[diam_test < 5] = 5.
    train_probs = 1. / nimg * np.ones(nimg,
                                      "float64") if train_probs is None else train_probs
    if test_files is not None or test_data is not None:
        test_probs = 1. / nimg_test * np.ones(
            nimg_test, "float64") if test_probs is None else test_probs

    tic = time.time()

    nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
    if test_files is not None or test_data is not None:
        nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch

    nbatch = 0
    train_losses, test_losses = [], []
    for iepoch in range(n_epochs):
        # seed per epoch so sampling is reproducible
        np.random.seed(iepoch)
        rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
                                 p=train_probs)
        torch.manual_seed(iepoch)
        np.random.seed(iepoch)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate[iepoch]
        lavg, lavg_per, nsum = 0, 0, 0
        for ibatch in range(0, nimg_per_epoch, batch_size * nnoise):
            inds = rperm[ibatch : ibatch + batch_size * nnoise]
            if train_data is None:
                imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds]
                lbls = [io.imread(train_labels_files[i])[1:] for i in inds]
            else:
                imgs = [train_data[i][:nchan] for i in inds]
                lbls = [train_labels[i][1:] for i in inds]
            # apply each noise regime to a different sub-batch, in random order
            rnoise = np.random.permutation(nnoise)
            for i, inoise in enumerate(rnoise):
                if i * batch_size < len(imgs):
                    imgi, lbli, scale = random_rotate_and_resize_noise(
                        imgs[i * batch_size : (i + 1) * batch_size],
                        lbls[i * batch_size : (i + 1) * batch_size],
                        diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(),
                        poisson=poisson[inoise],
                        beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso,
                        downsample=downsample[inoise], uniform_blur=uniform_blur,
                        diam_mean=diam_mean, ds_max=ds_max,
                        device=device)
                    if i == 0:
                        img = imgi
                        lbl = lbli
                    else:
                        img = torch.cat((img, imgi), axis=0)
                        lbl = torch.cat((lbl, lbli), axis=0)

            # shuffle so each gradient step mixes noise regimes
            if nnoise > 0:
                iperm = np.random.permutation(img.shape[0])
                img, lbl = img[iperm], lbl[iperm]

            for i in range(nnoise):
                optimizer.zero_grad()
                imgi = img[i * batch_size: (i + 1) * batch_size]
                lbli = lbl[i * batch_size: (i + 1) * batch_size]
                if imgi.shape[0] > 0:
                    # first nchan channels = noisy input, remaining = clean target
                    loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1,
                                                img=imgi[:, nchan:], lbl=lbli, lam=lam)
                    loss.backward()
                    optimizer.step()
                    lavg += loss.item() * imgi.shape[0]
                    lavg_per += loss_per.item() * imgi.shape[0]

            nsum += len(img)
            nbatch += 1

        if iepoch % 5 == 0 or iepoch < 10:
            lavg = lavg / nsum
            lavg_per = lavg_per / nsum
            if test_data is not None or test_files is not None:
                lavgt, nsum = 0., 0
                np.random.seed(42)
                rperm = np.random.choice(np.arange(0, nimg_test),
                                         size=(nimg_test_per_epoch,), p=test_probs)
                # cycle through noise regimes across evaluation epochs
                inoise = iepoch % nnoise
                torch.manual_seed(inoise)
                for ibatch in range(0, nimg_test_per_epoch, batch_size):
                    inds = rperm[ibatch:ibatch + batch_size]
                    if test_data is None:
                        imgs = [
                            np.maximum(0,
                                       io.imread(test_files[i])[:nchan]) for i in inds
                        ]
                        lbls = [io.imread(test_labels_files[i])[1:] for i in inds]
                    else:
                        imgs = [test_data[i][:nchan] for i in inds]
                        lbls = [test_labels[i][1:] for i in inds]
                    img, lbl, scale = random_rotate_and_resize_noise(
                        imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise],
                        beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise],
                        iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur,
                        diam_mean=diam_mean, ds_max=ds_max, device=device)
                    loss, loss_per = test_loss(net, img[:, :nchan], net1=net1,
                                               img=img[:, nchan:], lbl=lbl, lam=lam)

                    lavgt += loss.item() * img.shape[0]
                    nsum += len(img)
                lavgt = lavgt / nsum
                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f"
                    % (iepoch, time.time() - tic, lavg, lavg_per, lavgt,
                       learning_rate[iepoch]))
                test_losses.append(lavgt)
            else:
                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch]))
            train_losses.append(lavg)

        if save_path is not None:
            if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
                if save_each:  # separate files as model progresses
                    # NOTE: was f"{iepoch:%04d}", an invalid format spec that
                    # raised ValueError whenever save_each was enabled
                    filename0 = str(filename) + f"_epoch_{iepoch:04d}"
                else:
                    filename0 = filename
                denoise_logger.info(f"saving network parameters to {filename0}")
                net.save_model(filename0)
        else:
            filename = save_path

    return filename, train_losses, test_losses
1095
+
1096
+
1097
+ if __name__ == "__main__":
1098
+ import argparse
1099
+ parser = argparse.ArgumentParser(description="cellpose parameters")
1100
+
1101
+ input_img_args = parser.add_argument_group("input image arguments")
1102
+ input_img_args.add_argument("--dir", default=[], type=str,
1103
+ help="folder containing data to run or train on.")
1104
+ input_img_args.add_argument("--img_filter", default=[], type=str,
1105
+ help="end string for images to run on")
1106
+
1107
+ model_args = parser.add_argument_group("model arguments")
1108
+ model_args.add_argument("--pretrained_model", default=[], type=str,
1109
+ help="pretrained denoising model")
1110
+
1111
+ training_args = parser.add_argument_group("training arguments")
1112
+ training_args.add_argument("--test_dir", default=[], type=str,
1113
+ help="folder containing test data (optional)")
1114
+ training_args.add_argument("--file_list", default=[], type=str,
1115
+ help="npy file containing list of train and test files")
1116
+ training_args.add_argument("--seg_model_type", default="cyto2", type=str,
1117
+ help="model to use for seg training loss")
1118
+ training_args.add_argument(
1119
+ "--noise_type", default=[], type=str,
1120
+ help="noise type to use (if input, then other noise params are ignored)")
1121
+ training_args.add_argument("--poisson", default=0.8, type=float,
1122
+ help="fraction of images to add poisson noise to")
1123
+ training_args.add_argument("--beta", default=0.7, type=float,
1124
+ help="scale of poisson noise")
1125
+ training_args.add_argument("--blur", default=0., type=float,
1126
+ help="fraction of images to blur")
1127
+ training_args.add_argument("--gblur", default=1.0, type=float,
1128
+ help="scale of gaussian blurring stddev")
1129
+ training_args.add_argument("--downsample", default=0., type=float,
1130
+ help="fraction of images to downsample")
1131
+ training_args.add_argument("--ds_max", default=7, type=int,
1132
+ help="max downsampling factor")
1133
+ training_args.add_argument("--lam_per", default=1.0, type=float,
1134
+ help="weighting of perceptual loss")
1135
+ training_args.add_argument("--lam_seg", default=1.5, type=float,
1136
+ help="weighting of segmentation loss")
1137
+ training_args.add_argument("--lam_rec", default=0., type=float,
1138
+ help="weighting of reconstruction loss")
1139
+ training_args.add_argument(
1140
+ "--diam_mean", default=30., type=float, help=
1141
+ "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0"
1142
+ )
1143
+ training_args.add_argument("--learning_rate", default=0.001, type=float,
1144
+ help="learning rate. Default: %(default)s")
1145
+ training_args.add_argument("--n_epochs", default=2000, type=int,
1146
+ help="number of epochs. Default: %(default)s")
1147
+ training_args.add_argument(
1148
+ "--save_each", default=False, action="store_true",
1149
+ help="save each epoch as separate model")
1150
+ training_args.add_argument(
1151
+ "--nimg_per_epoch", default=0, type=int,
1152
+ help="number of images per epoch. Default is length of training images")
1153
+ training_args.add_argument(
1154
+ "--nimg_test_per_epoch", default=0, type=int,
1155
+ help="number of test images per epoch. Default is length of testing images")
1156
+
1157
+ io.logger_setup()
1158
+
1159
+ args = parser.parse_args()
1160
+ lams = [args.lam_per, args.lam_seg, args.lam_rec]
1161
+ print("lam", lams)
1162
+
1163
+ if len(args.noise_type) > 0:
1164
+ noise_type = args.noise_type
1165
+ uniform_blur = False
1166
+ iso = True
1167
+ if noise_type == "poisson":
1168
+ poisson = 0.8
1169
+ blur = 0.
1170
+ downsample = 0.
1171
+ beta = 0.7
1172
+ gblur = 1.0
1173
+ elif noise_type == "blur_expr":
1174
+ poisson = 0.8
1175
+ blur = 0.8
1176
+ downsample = 0.
1177
+ beta = 0.1
1178
+ gblur = 0.5
1179
+ elif noise_type == "blur":
1180
+ poisson = 0.8
1181
+ blur = 0.8
1182
+ downsample = 0.
1183
+ beta = 0.1
1184
+ gblur = 10.0
1185
+ uniform_blur = True
1186
+ elif noise_type == "downsample_expr":
1187
+ poisson = 0.8
1188
+ blur = 0.8
1189
+ downsample = 0.8
1190
+ beta = 0.03
1191
+ gblur = 1.0
1192
+ elif noise_type == "downsample":
1193
+ poisson = 0.8
1194
+ blur = 0.8
1195
+ downsample = 0.8
1196
+ beta = 0.03
1197
+ gblur = 5.0
1198
+ uniform_blur = True
1199
+ elif noise_type == "all":
1200
+ poisson = [0.8, 0.8, 0.8]
1201
+ blur = [0., 0.8, 0.8]
1202
+ downsample = [0., 0., 0.8]
1203
+ beta = [0.7, 0.1, 0.03]
1204
+ gblur = [0., 10.0, 5.0]
1205
+ uniform_blur = True
1206
+ elif noise_type == "aniso":
1207
+ poisson = 0.8
1208
+ blur = 0.8
1209
+ downsample = 0.8
1210
+ beta = 0.1
1211
+ gblur = args.ds_max * 1.5
1212
+ iso = False
1213
+ else:
1214
+ raise ValueError(f"{noise_type} noise_type is not supported")
1215
+ else:
1216
+ poisson, beta = args.poisson, args.beta
1217
+ blur, gblur = args.blur, args.gblur
1218
+ downsample = args.downsample
1219
+
1220
+ pretrained_model = None if len(
1221
+ args.pretrained_model) == 0 else args.pretrained_model
1222
+ model = DenoiseModel(gpu=True, nchan=1, diam_mean=args.diam_mean,
1223
+ pretrained_model=pretrained_model)
1224
+
1225
+ train_data, labels, train_files, train_probs = None, None, None, None
1226
+ test_data, test_labels, test_files, test_probs = None, None, None, None
1227
+ if len(args.file_list) == 0:
1228
+ output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0)
1229
+ images, labels, image_names, test_images, test_labels, image_names_test = output
1230
+ train_data = []
1231
+ for i in range(len(images)):
1232
+ img = images[i].astype("float32")
1233
+ if img.ndim > 2:
1234
+ img = img[0]
1235
+ train_data.append(
1236
+ np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
1237
+ if len(args.test_dir) > 0:
1238
+ test_data = []
1239
+ for i in range(len(test_images)):
1240
+ img = test_images[i].astype("float32")
1241
+ if img.ndim > 2:
1242
+ img = img[0]
1243
+ test_data.append(
1244
+ np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
1245
+ save_path = os.path.join(args.dir, "../models/")
1246
+ else:
1247
+ root = args.dir
1248
+ denoise_logger.info(
1249
+ ">>> using file_list (assumes images are normalized and have flows!)")
1250
+ dat = np.load(args.file_list, allow_pickle=True).item()
1251
+ train_files = dat["train_files"]
1252
+ test_files = dat["test_files"]
1253
+ train_probs = dat["train_probs"] if "train_probs" in dat else None
1254
+ test_probs = dat["test_probs"] if "test_probs" in dat else None
1255
+ if str(train_files[0])[:len(str(root))] != str(root):
1256
+ for i in range(len(train_files)):
1257
+ new_path = root / Path(*train_files[i].parts[-3:])
1258
+ if i == 0:
1259
+ print(f"changing path from {train_files[i]} to {new_path}")
1260
+ train_files[i] = new_path
1261
+
1262
+ for i in range(len(test_files)):
1263
+ new_path = root / Path(*test_files[i].parts[-3:])
1264
+ test_files[i] = new_path
1265
+ save_path = os.path.join(args.dir, "models/")
1266
+
1267
+ os.makedirs(save_path, exist_ok=True)
1268
+
1269
+ nimg_per_epoch = None if args.nimg_per_epoch == 0 else args.nimg_per_epoch
1270
+ nimg_test_per_epoch = None if args.nimg_test_per_epoch == 0 else args.nimg_test_per_epoch
1271
+
1272
+ model_path = train(
1273
+ model.net, train_data=train_data, train_labels=labels, train_files=train_files,
1274
+ test_data=test_data, test_labels=test_labels, test_files=test_files,
1275
+ train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta,
1276
+ blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max,
1277
+ iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs,
1278
+ learning_rate=args.learning_rate,
1279
+ lam=lams,
1280
+ seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch,
1281
+ nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path)
1282
+
1283
+
1284
def seg_train_noisy(model, train_data, train_labels, test_data=None, test_labels=None,
                    poisson=0.8, blur=0.0, downsample=0.0, save_path=None,
                    save_every=100, save_each=False, learning_rate=0.2, n_epochs=500,
                    momentum=0.9, weight_decay=0.00001, SGD=True, batch_size=8,
                    nimg_per_epoch=None, diameter=None, rescale=True, z_masking=False,
                    model_name=None):
    """Train a segmentation model on noise-corrupted images.

    Uses model.loss_fn from models.py via model._train_step / model._test_eval;
    data should already be normalized.

    Args:
        model: segmentation model wrapper (provides .net, .device, .diam_mean,
            _set_optimizer/_set_criterion/_train_step/_test_eval helpers).
        train_data / train_labels (list): normalized images and label arrays
            (labels[k][0] = masks, labels[k][1:] = flow targets).
        test_data / test_labels (list, optional): evaluation set.
        poisson, blur, downsample (float): corruption fractions passed to
            random_rotate_and_resize_noise.
        save_path / save_every / save_each / model_name: checkpointing options.
        learning_rate (float or sequence): constant LR or per-epoch schedule
            (sequence length must equal n_epochs).
        momentum, weight_decay, SGD, batch_size, n_epochs: optimizer options.
        nimg_per_epoch (int, optional): images sampled per epoch (capped at the
            dataset size).
        diameter (float, optional): fixed cell diameter; if None it is
            estimated from the training labels.
        rescale (bool): rescale images by diameter during augmentation.
            NOTE(review): if diameter is given and rescale is False, diam_train
            is never defined and the training loop raises NameError -- confirm
            intended usage.
        z_masking (bool): randomly zero leading/trailing channels per sample.

    Returns:
        str: path of the last saved checkpoint (or save_path when None).
    """

    d = datetime.datetime.now()

    model.n_epochs = n_epochs
    if isinstance(learning_rate, (list, np.ndarray)):
        if isinstance(learning_rate, np.ndarray) and learning_rate.ndim > 1:
            raise ValueError("learning_rate.ndim must equal 1")
        elif len(learning_rate) != n_epochs:
            raise ValueError(
                "if learning_rate given as list or np.ndarray it must have length n_epochs"
            )
        model.learning_rate = learning_rate
        model.learning_rate_const = mode(learning_rate)[0][0]
    else:
        model.learning_rate_const = learning_rate
        # set learning rate schedule
        if SGD:
            LR = np.linspace(0, model.learning_rate_const, 10)
            if model.n_epochs > 250:
                LR = np.append(
                    LR, model.learning_rate_const * np.ones(model.n_epochs - 100))
                for i in range(10):
                    LR = np.append(LR, LR[-1] / 2 * np.ones(10))
            else:
                LR = np.append(
                    LR,
                    model.learning_rate_const * np.ones(max(0, model.n_epochs - 10)))
        else:
            LR = model.learning_rate_const * np.ones(model.n_epochs)
        model.learning_rate = LR

    model.batch_size = batch_size
    model._set_optimizer(model.learning_rate[0], momentum, weight_decay, SGD)
    model._set_criterion()

    nimg = len(train_data)

    # compute average cell diameter
    if diameter is None:
        diam_train = np.array(
            [utils.diameters(train_labels[k][0])[0] for k in range(len(train_labels))])
        diam_train_mean = diam_train[diam_train > 0].mean()
        model.diam_labels = diam_train_mean
        if rescale:
            diam_train[diam_train < 5] = 5.
            if test_data is not None:
                diam_test = np.array([
                    utils.diameters(test_labels[k][0])[0]
                    for k in range(len(test_labels))
                ])
                diam_test[diam_test < 5] = 5.
            denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
    elif rescale:
        diam_train_mean = diameter
        model.diam_labels = diameter
        denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
        diam_train = diameter * np.ones(len(train_labels), "float32")
        if test_data is not None:
            diam_test = diameter * np.ones(len(test_labels), "float32")

    denoise_logger.info(
        f">>>> mean of training label mask diameters (saved to model) {diam_train_mean:.3f}"
    )
    model.net.diam_labels.data = torch.ones(1, device=model.device) * diam_train_mean

    nchan = train_data[0].shape[0]
    denoise_logger.info(">>>> training network with %d channel input <<<<" % nchan)
    denoise_logger.info(">>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f" %
                        (model.learning_rate_const, model.batch_size, weight_decay))

    if test_data is not None:
        denoise_logger.info(f">>>> ntrain = {nimg}, ntest = {len(test_data)}")
    else:
        denoise_logger.info(f">>>> ntrain = {nimg}")

    tic = time.time()

    lavg, nsum = 0, 0

    if save_path is not None:
        _, file_label = os.path.split(save_path)
        file_path = os.path.join(save_path, "models/")

        if not os.path.exists(file_path):
            os.makedirs(file_path)
    else:
        denoise_logger.warning("WARNING: no save_path given, model not saving")

    ksave = 0

    # get indices for each epoch for training
    np.random.seed(0)
    inds_all = np.zeros((0,), "int32")
    if nimg_per_epoch is None or nimg > nimg_per_epoch:
        nimg_per_epoch = nimg
    denoise_logger.info(f">>>> nimg_per_epoch = {nimg_per_epoch}")
    while len(inds_all) < n_epochs * nimg_per_epoch:
        rperm = np.random.permutation(nimg)
        inds_all = np.hstack((inds_all, rperm))

    for iepoch in range(model.n_epochs):
        if SGD:
            model._set_learning_rate(model.learning_rate[iepoch])
        np.random.seed(iepoch)
        rperm = inds_all[iepoch * nimg_per_epoch:(iepoch + 1) * nimg_per_epoch]
        for ibatch in range(0, nimg_per_epoch, batch_size):
            inds = rperm[ibatch:ibatch + batch_size]
            imgi, lbl, scale = random_rotate_and_resize_noise(
                [train_data[i] for i in inds], [train_labels[i][1:] for i in inds],
                poisson=poisson, blur=blur, downsample=downsample,
                diams=diam_train[inds], diam_mean=model.diam_mean)
            imgi = imgi[:, :1]  # keep noisy only
            if z_masking:
                # randomly zero out a prefix/suffix of channels per sample
                nc = imgi.shape[1]
                nb = imgi.shape[0]
                ncmin = (np.random.rand(nb) > 0.25) * (np.random.randint(
                    nc // 2 - 1, size=nb))
                ncmax = nc - (np.random.rand(nb) > 0.25) * (np.random.randint(
                    nc // 2 - 1, size=nb))
                for b in range(nb):
                    imgi[b, :ncmin[b]] = 0
                    imgi[b, ncmax[b]:] = 0

            train_loss = model._train_step(imgi, lbl)
            lavg += train_loss
            nsum += len(imgi)

        if iepoch % 10 == 0 or iepoch == 5:
            lavg = lavg / nsum
            if test_data is not None:
                lavgt, nsum = 0., 0
                np.random.seed(42)
                rperm = np.arange(0, len(test_data), 1, int)
                for ibatch in range(0, len(test_data), batch_size):
                    inds = rperm[ibatch:ibatch + batch_size]
                    imgi, lbl, scale = random_rotate_and_resize_noise(
                        [test_data[i] for i in inds],
                        [test_labels[i][1:] for i in inds], poisson=poisson, blur=blur,
                        downsample=downsample, diams=diam_test[inds],
                        diam_mean=model.diam_mean)
                    imgi = imgi[:, :1]  # keep noisy only
                    test_loss = model._test_eval(imgi, lbl)
                    lavgt += test_loss
                    nsum += len(imgi)

                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %2.4f, Loss Test %2.4f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, lavgt / nsum,
                     model.learning_rate[iepoch]))
            else:
                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %2.4f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, model.learning_rate[iepoch]))

            lavg, nsum = 0, 0

        if save_path is not None:
            if iepoch == model.n_epochs - 1 or iepoch % save_every == 1:
                # save model at the end
                if save_each:  # separate files as model progresses
                    if model_name is None:
                        filename = "{}_{}_{}_{}".format(
                            model.net_type, file_label,
                            d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch))
                    else:
                        filename = "{}_{}".format(model_name, "epoch_" + str(iepoch))
                else:
                    if model_name is None:
                        filename = "{}_{}_{}".format(model.net_type, file_label,
                                                     d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
                    else:
                        filename = model_name
                filename = os.path.join(file_path, filename)
                ksave += 1
                # FIX: restore the path in the log message (it previously
                # printed the literal "(unknown)" instead of the filename)
                denoise_logger.info(f"saving network parameters to {filename}")
                model.net.save_model(filename)
        else:
            filename = save_path

    return filename
models/seg_post_model/cellpose/dynamics.py ADDED
@@ -0,0 +1,691 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
3
+ """
4
+ import os
5
+ from scipy.ndimage import find_objects, center_of_mass, mean
6
+ import torch
7
+ import numpy as np
8
+ import tifffile
9
+ from tqdm import trange
10
+ import fastremap
11
+
12
+ import logging
13
+
14
+ dynamics_logger = logging.getLogger(__name__)
15
+
16
+ from . import utils
17
+
18
+ import torch
19
+ import torch.nn.functional as F
20
+
21
def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
                        device=torch.device("cpu")):
    """Runs diffusion from mask centers to generate flows for training images or quality control.

    Each iteration deposits one unit of "heat" at every mask center and then
    replaces every masked pixel by the mean over its valid neighbors; the
    flows are differences of the resulting heat map between opposing neighbors.

    Args:
        neighbors (torch.Tensor): index tensor for the 9 neighbors of each
            pixel in the masks, one leading row of indices per spatial
            dimension (used as ``T[tuple(neighbors)]``).
        meds (torch.Tensor): Mask center coordinates, one row per center.
        isneighbor (torch.Tensor): Valid neighbor boolean, 9 x pixels (True
            where the neighbor belongs to the same mask).
        shape (tuple): Shape of the diffusion tensor to allocate.
        n_iter (int, optional): Number of diffusion iterations. Defaults to 200.
        device (torch.device, optional): Device to run the computation on.
            Defaults to torch.device("cpu").

    Returns:
        np.ndarray: Flow components stacked on axis -2 -- (dy, dx) in 2D,
        (dz, dy, dx) in 3D -- moved to the CPU. (Despite the variable name
        ``mu_torch``, np.stack returns a NumPy array.)
    """
    # float32 for very large tensors or on MPS devices, float64 otherwise
    # (NOTE(review): presumably float64 is unsupported/slow on MPS -- confirm)
    if torch.prod(torch.tensor(shape)) > 4e7 or device.type == "mps":
        T = torch.zeros(shape, dtype=torch.float, device=device)
    else:
        T = torch.zeros(shape, dtype=torch.double, device=device)

    for i in range(n_iter):
        # re-inject heat at the mask centers every iteration
        T[tuple(meds.T)] += 1
        Tneigh = T[tuple(neighbors)]
        # zero contributions from neighbors outside the mask
        Tneigh *= isneighbor
        # diffuse: each pixel becomes the mean of its neighborhood
        T[tuple(neighbors[:, 0])] = Tneigh.mean(axis=0)
    del meds, isneighbor, Tneigh

    if T.ndim == 2:
        # 2D: gradients from differences between opposing neighbors
        # (positions 2/1 along y, 4/3 along x)
        grads = T[neighbors[0, [2, 1, 4, 3]], neighbors[1, [2, 1, 4, 3]]]
        del neighbors
        dy = grads[0] - grads[1]
        dx = grads[2] - grads[3]
        del grads
        mu_torch = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
    else:
        # 3D: differences of the first six non-center neighbors give z, y, x
        grads = T[tuple(neighbors[:, 1:])]
        del neighbors
        dz = grads[0] - grads[1]
        dy = grads[2] - grads[3]
        dx = grads[4] - grads[5]
        del grads
        mu_torch = np.stack(
            (dz.cpu().squeeze(0), dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
    return mu_torch
66
+
67
def center_of_mass(mask):
    """Return an integer (y, x) center for a binary mask, guaranteed to lie on the mask.

    The rounded mean coordinate is used when it falls on a mask pixel;
    otherwise the mask pixel closest to the rounded mean is returned.
    """
    rows, cols = np.nonzero(mask)
    cy = int(np.round(rows.sum() / len(rows)))
    cx = int(np.round(cols.sum() / len(cols)))
    mean_on_mask = ((rows == cy) * (cols == cx)).sum()
    if not mean_on_mask:
        # fall back to the mask pixel nearest to (cy, cx)
        nearest = ((cols - cx) ** 2 + (rows - cy) ** 2).argmin()
        cy = rows[nearest]
        cx = cols[nearest]

    return cy, cx
78
+
79
def get_centers(masks, slices):
    """Compute per-mask center coordinates and extents.

    Args:
        masks: labelled 2D array, 0=background, 1..N = mask labels.
        slices: per-label bounding-box slices as returned by
            scipy.ndimage.find_objects (one entry per label, in order).

    Returns:
        (centers, exts): centers is an (N, 2) int array of absolute (y, x)
        mask centers; exts is an (N,) array of bounding-box extents
        (height + width + 2), used to size the diffusion iteration count.
    """
    abs_centers = []
    for label_idx, slc in enumerate(slices):
        cy, cx = center_of_mass(masks[slc] == (label_idx + 1))
        # shift from bounding-box-local to absolute image coordinates
        abs_centers.append(np.array([cy + slc[0].start, cx + slc[1].start]))
    exts = np.array([(slc[0].stop - slc[0].start) +
                     (slc[1].stop - slc[1].start) + 2 for slc in slices])
    return np.array(abs_centers), exts
85
+
86
+
87
def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
    """Convert masks to flows using diffusion from center pixel.

    Center of masks where diffusion starts is defined by the mask pixel closest
    to the within-mask center of mass.

    Args:
        masks (int, 2D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
        device (torch.device, optional): The device to run the computation on.
            Defaults to torch.device("cpu"); if explicitly passed as None, cuda
            or mps is selected when available.
        niter (int, optional): Number of iterations for the diffusion process.
            Defaults to None (2x the largest mask extent is used).

    Returns:
        np.ndarray: float 3D array of flows [2 x Ly x Lx] (dy, dx),
        unit-length inside masks and zero outside.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

    if masks.max() > 0:
        Ly0, Lx0 = masks.shape
        Ly, Lx = Ly0 + 2, Lx0 + 2  # padded size (1-pixel border)

        masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
        masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
        shape = masks_padded.shape

        ### get mask pixel neighbors (center pixel + 8-connectivity offsets)
        y, x = torch.nonzero(masks_padded, as_tuple=True)
        y = y.int()
        x = x.int()
        neighbors = torch.zeros((2, 9, y.shape[0]), dtype=torch.int, device=device)
        yxi = [[0, -1, 1, 0, 0, -1, -1, 1, 1], [0, 0, 0, -1, 1, -1, 1, -1, 1]]
        for i in range(9):
            neighbors[0, i] = y + yxi[0][i]
            neighbors[1, i] = x + yxi[1][i]
        # a neighbor is valid only if it carries the same mask label as the center
        isneighbor = torch.ones((9, y.shape[0]), dtype=torch.bool, device=device)
        m0 = masks_padded[neighbors[0, 0], neighbors[1, 0]]
        for i in range(1, 9):
            isneighbor[i] = masks_padded[neighbors[0, i], neighbors[1, i]] == m0
        del m0, masks_padded

        ### get center-of-mass within cell
        slices = find_objects(masks)
        centers, ext = get_centers(masks, slices)
        meds_p = torch.from_numpy(centers).to(device).long()
        meds_p += 1  # for padding

        ### run diffusion
        n_iter = 2 * ext.max() if niter is None else niter
        mu = _extend_centers_gpu(neighbors, meds_p, isneighbor, shape, n_iter=n_iter,
                                 device=device)
        mu = mu.astype("float64")

        # new normalization: make flow vectors unit length
        mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)

        # put into original image (undo the 1-pixel padding offset)
        mu0 = np.zeros((2, Ly0, Lx0))
        mu0[:, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
    else:
        # no masks, return empty flows
        mu0 = np.zeros((2, masks.shape[0], masks.shape[1]))
    return mu0
153
+
154
def masks_to_flows_gpu_3d(masks, device=None, niter=None):
    """Convert 3D masks to flows using diffusion from center pixel.

    Args:
        masks (int, 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
        device (torch.device, optional): The device to run the computation on.
            Defaults to None (cuda or mps is selected when available).
        niter (int, optional): Number of iterations for the diffusion process.
            Defaults to None (6x the largest summed bounding-box extent is used).

    Returns:
        np.ndarray: A 4D array [3 x Lz x Ly x Lx] representing the flows for
        each pixel in Z, Y and X, zero outside masks.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

    Lz0, Ly0, Lx0 = masks.shape
    Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2  # padded size (1-pixel border)

    masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
    masks_padded = F.pad(masks_padded, (1, 1, 1, 1, 1, 1))

    # get mask pixel neighbors (center pixel + 6 face neighbors)
    z, y, x = torch.nonzero(masks_padded).T
    neighborsZ = torch.stack((z, z + 1, z - 1, z, z, z, z))
    neighborsY = torch.stack((y, y, y, y + 1, y - 1, y, y), axis=0)
    neighborsX = torch.stack((x, x, x, x, x, x + 1, x - 1), axis=0)

    neighbors = torch.stack((neighborsZ, neighborsY, neighborsX), axis=0)

    # get mask centers: the mask pixel closest to the mask's mean coordinate
    slices = find_objects(masks)

    centers = np.zeros((masks.max(), 3), "int")
    for i, si in enumerate(slices):
        if si is not None:
            sz, sy, sx = si
            zi, yi, xi = np.nonzero(masks[sz, sy, sx] == (i + 1))
            zi = zi.astype(np.int32) + 1  # add padding
            yi = yi.astype(np.int32) + 1  # add padding
            xi = xi.astype(np.int32) + 1  # add padding
            zmed = np.mean(zi)
            ymed = np.mean(yi)
            xmed = np.mean(xi)
            imin = np.argmin((zi - zmed)**2 + (xi - xmed)**2 + (yi - ymed)**2)
            zmed = zi[imin]
            ymed = yi[imin]
            xmed = xi[imin]
            centers[i, 0] = zmed + sz.start
            centers[i, 1] = ymed + sy.start
            centers[i, 2] = xmed + sx.start

    # get neighbor validator (not all neighbors are in same mask)
    neighbor_masks = masks_padded[tuple(neighbors)]
    isneighbor = neighbor_masks == neighbor_masks[0]
    ext = np.array(
        [[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1]
         for sz, sy, sx in slices])
    n_iter = 6 * (ext.sum(axis=1)).max() if niter is None else niter

    # run diffusion
    shape = masks_padded.shape
    mu = _extend_centers_gpu(neighbors, centers, isneighbor, shape, n_iter=n_iter,
                             device=device)
    # normalize flow vectors to unit length
    mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)

    # put into original image (undo the 1-pixel padding offset)
    mu0 = np.zeros((3, Lz0, Ly0, Lx0))
    mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
    return mu0
224
+
225
def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None,
                    return_flows=True):
    """Converts labels (list of masks or flows) to flows for training model.

    Args:
        labels (list of ND-arrays): The labels to convert. labels[k] can be 2D or 3D. If [3 x Ly x Lx],
            it is assumed that flows were precomputed. Otherwise, labels[k][0] or labels[k] (if 2D)
            is used to create flows and cell probabilities.
        files (list of str, optional): The files to save the flows to. If provided, flows are saved to
            files to be reused. Defaults to None.
        device (str, optional): The device to use for computation. Defaults to None.
        redo_flows (bool, optional): Whether to recompute the flows. Defaults to False.
        niter (int, optional): The number of iterations for computing flows. Defaults to None.
        return_flows (bool, optional): Whether to return the computed flows; set False to
            only write them to disk (with files). Defaults to True.

    Returns:
        list of [4 x Ly x Lx] arrays: The flows for training the model. flows[k][0] is labels[k],
        flows[k][1] is the binary cell mask (labels > 0.5), flows[k][2] is Y flow,
        and flows[k][3] is X flow.
    """
    nimg = len(labels)
    if labels[0].ndim < 3:
        labels = [labels[n][np.newaxis, :, :] for n in range(nimg)]

    flows = []
    # flows need to be recomputed
    if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows:
        dynamics_logger.info("computing flows for labels")

        # compute flows; labels are fixed here to be unique, so they need to be passed back
        # make sure labels are unique!
        labels = [fastremap.renumber(label, in_place=True)[0] for label in labels]
        iterator = trange if nimg > 1 else range
        for n in iterator(nimg):
            labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0]
            vecn = masks_to_flows_gpu(labels[n][0].astype(int), device=device, niter=niter)

            # concatenate labels, binary cell mask, and the 2-component vector flows
            flow = np.concatenate((labels[n], labels[n] > 0.5, vecn),
                                  axis=0).astype(np.float32)
            if files is not None:
                file_name = os.path.splitext(files[n])[0]
                tifffile.imwrite(file_name + "_flows.tif", flow)
            if return_flows:
                flows.append(flow)
    else:
        dynamics_logger.info("flows precomputed")
        if return_flows:
            flows = [labels[n].astype(np.float32) for n in range(nimg)]
    return flows
274
+
275
+
276
def flow_error(maski, dP_net, device=None):
    """Error in flows from predicted masks vs flows predicted by network run on image.

    This function serves to benchmark the quality of masks. It works as follows:
    1. The predicted masks are used to create a flow diagram.
    2. The mask-flows are compared to the flows that the network predicted.

    If there is a discrepancy between the flows, it suggests that the mask is incorrect.
    Masks with flow_errors greater than 0.4 are discarded by default. This setting can be
    changed in Cellpose.eval or CellposeModel.eval.

    Args:
        maski (np.ndarray, int): Masks produced from running dynamics on dP_net, where 0=NO masks; 1,2... are mask labels.
        dP_net (np.ndarray, float): ND flows where dP_net.shape[1:] = maski.shape.
        device (torch.device, optional): Device used to recompute flows from the masks.

    Returns:
        A tuple containing (flow_errors, dP_masks): flow_errors (np.ndarray, float): Mean squared error between predicted flows and flows from masks;
        dP_masks (np.ndarray, float): ND flows produced from the predicted masks.
        Returns None when dP_net and maski shapes are incompatible.
    """
    if dP_net.shape[1:] != maski.shape:
        # was a bare print(); use the module logger like the rest of the file
        dynamics_logger.error("net flow is not same size as predicted masks")
        return None

    # flows predicted from estimated masks
    dP_masks = masks_to_flows_gpu(maski, device=device)
    # difference between predicted flows vs mask flows;
    # dP_net is divided by 5., matching the scaling applied in compute_masks
    flow_errors = np.zeros(maski.max())
    for i in range(dP_masks.shape[0]):
        flow_errors += mean((dP_masks[i] - dP_net[i] / 5.)**2, maski,
                            index=np.arange(1,
                                            maski.max() + 1))

    return flow_errors, dP_masks
309
+
310
+
311
def steps_interp(dP, inds, niter, device=torch.device("cpu")):
    """ Run dynamics of pixels to recover masks in 2D/3D, with interpolation between pixel values.

    Euler integration of dynamics dP for niter steps. The flow field is sampled
    at the (sub-pixel) current positions using torch grid_sample.

    Args:
        dP (numpy.ndarray): Array of shape (2, Ly, Lx) or (3, Lz, Ly, Lx) representing the flow field.
        inds (tuple of numpy.ndarray): Initial pixel coordinates, as returned by np.nonzero.
        niter (int): Number of iterations to perform.
        device (torch.device, optional): Device to use for computation. Defaults to torch.device("cpu").

    Returns:
        torch.Tensor: Final pixel locations, shape (ndim, n_points), in (z,)y,x order.
    """

    shape = dP.shape[1:]
    ndim = len(shape)

    # pt holds positions in grid_sample's channel order (x, y[, z]) — hence the
    # reversed index ndim - n - 1 below
    pt = torch.zeros((*[1]*ndim, len(inds[0]), ndim), dtype=torch.float32, device=device)
    im = torch.zeros((1, ndim, *shape), dtype=torch.float32, device=device)
    # Y and X dimensions, flipped X-1, Y-1
    # pt is [1 1 1 3 n_points]
    for n in range(ndim):
        if ndim==3:
            pt[0, 0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
        else:
            pt[0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
        im[0, ndim - n - 1] = torch.from_numpy(dP[n]).to(device, dtype=torch.float32)
    shape = np.array(shape)[::-1].astype("float") - 1

    # normalize pt between 0 and 1, normalize the flow
    for k in range(ndim):
        im[:, k] *= 2. / shape[k]
        pt[..., k] /= shape[k]

    # normalize to between -1 and 1 (the coordinate range grid_sample expects)
    pt *= 2
    pt -= 1

    # dynamics
    for t in range(niter):
        dPt = torch.nn.functional.grid_sample(im, pt, align_corners=False)
        for k in range(ndim): #clamp the final pixel locations
            pt[..., k] = torch.clamp(pt[..., k] + dPt[:, k], -1., 1.)

    #undo the normalization from before, reverse order of operations
    pt += 1
    pt *= 0.5
    for k in range(ndim):
        pt[..., k] *= shape[k]

    # flip back to (z,)y,x coordinate order; keep a 2D result even for a single point
    if ndim==3:
        pt = pt[..., [2, 1, 0]].squeeze()
        pt = pt.unsqueeze(0) if pt.ndim==1 else pt
        return pt.T
    else:
        pt = pt[..., [1, 0]].squeeze()
        pt = pt.unsqueeze(0) if pt.ndim==1 else pt
        return pt.T
374
+
375
def follow_flows(dP, inds, niter=200, device=torch.device("cpu")):
    """ Run dynamics to recover masks in 2D or 3D.

    Only the pixels listed in inds (typically those above the cell-probability
    threshold) are advected through the flow field.

    Args:
        dP (np.ndarray): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
        inds (tuple of np.ndarray): Coordinates of the pixels to run dynamics on,
            as returned by np.nonzero.
        niter (int, optional): Number of iterations of dynamics to run. Default is 200.
        device (torch.device, optional): Device to use for computation.
            Default is torch.device("cpu").

    Returns:
        torch.Tensor: Final locations of each pixel after dynamics,
        shape (ndim, n_points).
    """
    # thin wrapper kept for API compatibility; unused locals removed
    return steps_interp(dP, inds, niter, device=device)
398
+
399
+
400
def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu")):
    """Remove masks which have inconsistent flows.

    Uses metrics.flow_error to compute flows from predicted masks
    and compare flows to predicted flows from the network. Discards
    masks with flow errors greater than the threshold.

    Args:
        masks (int, 2D or 3D array): Labelled masks, 0=NO masks; 1,2,...=mask labels,
            size [Ly x Lx] or [Lz x Ly x Lx].
        flows (float, 3D or 4D array): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
        threshold (float, optional): Masks with flow error greater than threshold are discarded.
            Default is 0.4.
        device (torch.device, optional): Device used to recompute flows; falls back to
            CPU when GPU memory is insufficient for very large images.

    Returns:
        masks (int, 2D or 3D array): Masks with inconsistent flow masks removed,
        0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx].
    """
    device0 = device
    # for very large images, check whether the flow computation fits in GPU memory
    if masks.size > 10000 * 10000 and (device is not None and device.type == "cuda"):

        major_version, minor_version = torch.__version__.split(".")[:2]
        torch.cuda.empty_cache()
        if major_version == "1" and int(minor_version) < 10:
            # for PyTorch version lower than 1.10 (no mem_get_info)
            def mem_info():
                total_mem = torch.cuda.get_device_properties(device0.index).total_memory
                used_mem = torch.cuda.memory_allocated(device0.index)
                free_mem = total_mem - used_mem
                return total_mem, free_mem
        else:
            # for PyTorch version 1.10 and above
            def mem_info():
                free_mem, total_mem = torch.cuda.mem_get_info(device0.index)
                return total_mem, free_mem
        total_mem, free_mem = mem_info()
        if masks.size * 32 > free_mem:
            dynamics_logger.warning(
                "WARNING: image is very large, not using gpu to compute flows from masks for QC step flow_threshold"
            )
            dynamics_logger.info("turn off QC step with flow_threshold=0 if too slow")
            device0 = torch.device("cpu")

    merrors, _ = flow_error(masks, flows, device0)
    # labels are 1-based while the error array is 0-based, hence the +1
    badi = 1 + (merrors > threshold).nonzero()[0]
    masks[np.isin(masks, badi)] = 0
    return masks
447
+
448
+
449
def max_pool1d(h, kernel_size=5, axis=1, out=None):
    """ memory efficient max_pool thanks to Mark Kittisopikul

    Equivalent to a max filter with stride=1 and padding=kernel_size//2;
    requires odd kernel_size >= 3. Border windows are truncated (the max is
    taken only over the in-bounds part of the window).

    Args:
        h (torch.Tensor): Input tensor with a leading batch dimension.
        kernel_size (int, optional): Odd window size >= 3. Defaults to 5.
        axis (int, optional): Axis to pool over; must be 1, 2 or 3. Defaults to 1.
        out (torch.Tensor, optional): Preallocated output buffer, overwritten in
            place when provided. Defaults to None (a fresh clone of h is used).

    Returns:
        torch.Tensor: Max-pooled tensor, same shape as h.

    Raises:
        ValueError: If axis is not 1, 2 or 3 (previously this failed with a
            confusing NameError).
    """
    if axis not in (1, 2, 3):
        raise ValueError(f"axis must be 1, 2 or 3, got {axis}")
    if out is None:
        out = h.clone()
    else:
        out.copy_(h)

    nd = h.shape[axis]
    k0 = kernel_size // 2
    for d in range(-k0, k0+1):
        # shifted views: mv (write) over out, hv (read) over the untouched h;
        # reading from h avoids cascading maxima across passes
        if axis==1:
            mv = out[:, max(-d,0):min(nd-d,nd)]
            hv = h[:, max(d,0):min(nd+d,nd)]
        elif axis==2:
            mv = out[:, :, max(-d,0):min(nd-d,nd)]
            hv = h[:, :, max(d,0):min(nd+d,nd)]
        else:
            mv = out[:, :, :, max(-d,0):min(nd-d,nd)]
            hv = h[:, :, :, max(d,0):min(nd+d,nd)]
        torch.maximum(mv, hv, out=mv)
    return out
474
+
475
def max_pool_nd(h, kernel_size=5):
    """Memory-efficient max pooling over the 2 or 3 spatial axes of h.

    h has a leading batch dimension followed by 2 or 3 spatial axes; the
    pooling is applied separately along each spatial axis, which for a box
    kernel equals the full ND max pool.
    """
    spatial_dims = h.ndim - 1
    pooled_a = max_pool1d(h, kernel_size=kernel_size, axis=1)
    pooled_b = max_pool1d(pooled_a, kernel_size=kernel_size, axis=2)
    if spatial_dims == 2:
        del pooled_a
        return pooled_b
    # 3D: reuse the first intermediate as the output buffer to save memory
    pooled_a = max_pool1d(pooled_b, kernel_size=kernel_size, axis=3, out=pooled_a)
    del pooled_b
    return pooled_a
487
+
488
def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
    """Create masks using pixel convergence after running dynamics.

    Makes a histogram of final pixel locations pt, initializes masks
    at peaks of the histogram and extends the masks from the peaks so that
    they include all pixels with more than 2 final pixels. Masks larger than
    max_size_fraction of the image are discarded.

    Parameters:
        pt (torch.Tensor, int): Final locations of each pixel after dynamics,
            size [ndim x n_points] (modified in place by the padding offset).
        inds (tuple of np.ndarray): Original coordinates of the advected pixels,
            as returned by np.nonzero.
        shape0 (tuple): Shape of the original (unpadded) image.
        rpad (int, optional): Histogram edge padding. Default is 20.
            NOTE(review): values < 5 are untested — seed neighborhoods extend
            5 pixels past each seed.
        max_size_fraction (float, optional): Masks larger than max_size_fraction of
            total image size are removed. Default is 0.4.

    Returns:
        M0 (int, 2D or 3D array): Masks, 0=NO masks; 1,2,...=mask labels,
        size [Ly x Lx] or [Lz x Ly x Lx].
    """

    ndim = len(shape0)
    device = pt.device

    # BUGFIX: rpad used to be hard-coded to 20 here, silently ignoring the
    # parameter; the parameter is now honored (default unchanged)
    pt += rpad
    pt = torch.clamp(pt, min=0)
    for i in range(len(pt)):
        pt[i] = torch.clamp(pt[i], max=shape0[i]+rpad-1)

    # # add extra padding to make divisible by 5
    # shape = tuple((np.ceil((shape0 + 2*rpad)/5) * 5).astype(int))
    shape = tuple(np.array(shape0) + 2*rpad)

    # histogram of final pixel locations, built as a sparse COO tensor
    coo = torch.sparse_coo_tensor(pt, torch.ones(pt.shape[1], device=pt.device, dtype=torch.int),
                                  shape)
    h1 = coo.to_dense()
    del coo

    # seeds = local maxima of the histogram with more than 10 converged pixels
    hmax1 = max_pool_nd(h1.unsqueeze(0), kernel_size=5)
    hmax1 = hmax1.squeeze()
    seeds1 = torch.nonzero((h1 - hmax1 > -1e-6) * (h1 > 10))
    del hmax1
    if len(seeds1) == 0:
        dynamics_logger.warning("no seeds found in get_masks_torch - no masks found.")
        return np.zeros(shape0, dtype="uint16")

    # process seeds in order of increasing histogram count
    npts = h1[tuple(seeds1.T)]
    isort1 = npts.argsort()
    seeds1 = seeds1[isort1]

    # extract the 11x11(x11) histogram neighborhood around each seed
    n_seeds = len(seeds1)
    h_slc = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
    for k in range(n_seeds):
        slc = tuple([slice(seeds1[k][j]-5, seeds1[k][j]+6) for j in range(ndim)])
        h_slc[k] = h1[slc]
    del h1
    # grow each seed inside its neighborhood
    seed_masks = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
    if ndim==2:
        seed_masks[:,5,5] = 1
    else:
        seed_masks[:,5,5,5] = 1

    for _ in range(5):
        # extend to adjacent cells that attracted more than 2 pixels
        seed_masks = max_pool_nd(seed_masks, kernel_size=3)
        seed_masks *= h_slc > 2
    del h_slc
    seeds_new = [tuple((torch.nonzero(seed_masks[k]) + seeds1[k] - 5).T)
                 for k in range(n_seeds)]
    del seed_masks

    # paint labels into the padded histogram space
    dtype = torch.int32 if n_seeds < 2**16 else torch.int64
    M1 = torch.zeros(shape, dtype=dtype, device=device)
    for k in range(n_seeds):
        M1[seeds_new[k]] = 1 + k

    # read off the label each advected pixel landed on
    M1 = M1[tuple(pt)]
    M1 = M1.cpu().numpy()

    # scatter labels back to the original pixel positions
    dtype = "uint16" if n_seeds < 2**16 else "uint32"
    M0 = np.zeros(shape0, dtype=dtype)
    M0[inds] = M1

    # remove big masks
    uniq, counts = fastremap.unique(M0, return_counts=True)
    big = np.prod(shape0) * max_size_fraction
    bigc = uniq[counts > big]
    if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
        M0 = fastremap.mask(M0, bigc)
    fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
    M0 = M0.reshape(tuple(shape0))

    #print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9:.3f} gb")
    return M0
585
+
586
+
587
def resize_and_compute_masks(dP, cellprob, niter=200, cellprob_threshold=0.0,
                             flow_threshold=0.4, do_3D=False, min_size=15,
                             max_size_fraction=0.4, resize=None, device=torch.device("cpu")):
    """Compute masks using dynamics from dP and cellprob, then fill holes and
    remove masks smaller than min_size.

    Args:
        dP (numpy.ndarray): The dynamics flow field array.
        cellprob (numpy.ndarray): The cell probability array.
        niter (int, optional): The number of iterations for mask computation. Defaults to 200.
        cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0.
        flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4.
        do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
        min_size (int, optional): The minimum size of the masks. Defaults to 15.
        max_size_fraction (float, optional): Masks larger than max_size_fraction of
            total image size are removed. Default is 0.4.
        resize (tuple, optional): Deprecated; setting it only triggers a warning. Defaults to None.
        device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu").

    Returns:
        numpy.ndarray: The computed masks.
    """
    mask = compute_masks(dP, cellprob, niter=niter,
                         cellprob_threshold=cellprob_threshold,
                         flow_threshold=flow_threshold, do_3D=do_3D,
                         max_size_fraction=max_size_fraction,
                         device=device)

    if resize is not None:
        # fixed typo in the warning message ("depricated" -> "deprecated")
        dynamics_logger.warning("Resizing is deprecated in v4.0.1+")

    mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)

    return mask
622
+
623
+
624
def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
                  flow_threshold=0.4, do_3D=False, min_size=-1,
                  max_size_fraction=0.4, device=torch.device("cpu")):
    """Compute masks using dynamics from dP and cellprob.

    Args:
        dP (numpy.ndarray): The dynamics flow field array.
        cellprob (numpy.ndarray): The cell probability array.
        p (numpy.ndarray, optional): Unused; kept for backward compatibility. Defaults to None.
        niter (int, optional): The number of iterations for mask computation. Defaults to 200.
        cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0.
        flow_threshold (float, optional): The threshold for quality control metrics;
            set to None or <= 0 to skip the flow QC step. Defaults to 0.4.
        do_3D (bool, optional): Whether to perform mask computation in 3D (skips flow QC). Defaults to False.
        min_size (int, optional): The minimum size of the masks; cleanup is skipped
            when <= 0. Defaults to -1.
        max_size_fraction (float, optional): Masks larger than max_size_fraction of
            total image size are removed. Default is 0.4.
        device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu").

    Returns:
        numpy.ndarray: The computed masks (uint16, or uint32 when >= 2**16 masks).
    """

    if (cellprob > cellprob_threshold).sum(): #mask at this point is a cell cluster binary map, not labels
        inds = np.nonzero(cellprob > cellprob_threshold)
        # defensive check; unreachable in practice since the surrounding sum() is nonzero
        if len(inds[0]) == 0:
            dynamics_logger.info("No cell pixels found.")
            shape = cellprob.shape
            mask = np.zeros(shape, "uint16")
            return mask

        # advect above-threshold pixels through the (rescaled) flow field
        p_final = follow_flows(dP * (cellprob > cellprob_threshold) / 5.,
                               inds=inds, niter=niter,
                               device=device)
        if not torch.is_tensor(p_final):
            p_final = torch.from_numpy(p_final).to(device, dtype=torch.int)
        else:
            p_final = p_final.int()
        # calculate masks
        if device.type == "mps":
            # mask assembly is not supported on mps; move to cpu first
            p_final = p_final.to(torch.device("cpu"))
        mask = get_masks_torch(p_final, inds, dP.shape[1:],
                               max_size_fraction=max_size_fraction)
        del p_final
        # flow thresholding factored out of get_masks
        if not do_3D:
            if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0:
                # make sure labels are unique at output of get_masks
                mask = remove_bad_flow_masks(mask, dP, threshold=flow_threshold,
                                             device=device)

        if mask.max() < 2**16 and mask.dtype != "uint16":
            mask = mask.astype("uint16")

    else:  # nothing to compute, just make it compatible
        dynamics_logger.info("No cell pixels found.")
        shape = cellprob.shape
        mask = np.zeros(cellprob.shape, "uint16")
        return mask

    if min_size > 0:
        mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)

    if mask.dtype == np.uint32:
        dynamics_logger.warning(
            "more than 65535 masks in image, masks returned as np.uint32")

    return mask
models/seg_post_model/cellpose/export.py ADDED
@@ -0,0 +1,405 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Auxiliary module for bioimageio format export
2
+
3
+ Example usage:
4
+
5
+ ```bash
6
+ #!/bin/bash
7
+
8
+ # Define default paths and parameters
9
+ DEFAULT_CHANNELS="1 0"
10
+ DEFAULT_PATH_PRETRAINED_MODEL="/home/qinyu/models/cp/cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995"
11
+ DEFAULT_PATH_README="/home/qinyu/models/cp/README.md"
12
+ DEFAULT_LIST_PATH_COVER_IMAGES="/home/qinyu/images/cp/cellpose_raw_and_segmentation.jpg /home/qinyu/images/cp/cellpose_raw_and_probability.jpg /home/qinyu/images/cp/cellpose_raw.jpg"
13
+ DEFAULT_MODEL_ID="philosophical-panda"
14
+ DEFAULT_MODEL_ICON="🐼"
15
+ DEFAULT_MODEL_VERSION="0.1.0"
16
+ DEFAULT_MODEL_NAME="My Cool Cellpose"
17
+ DEFAULT_MODEL_DOCUMENTATION="A cool Cellpose model trained for my cool dataset."
18
+ DEFAULT_MODEL_AUTHORS='[{"name": "Qin Yu", "affiliation": "EMBL", "github_user": "qin-yu", "orcid": "0000-0002-4652-0795"}]'
19
+ DEFAULT_MODEL_CITE='[{"text": "For more details of the model itself, see the manuscript", "doi": "10.1242/dev.202800", "url": null}]'
20
+ DEFAULT_MODEL_TAGS="cellpose 3d 2d"
21
+ DEFAULT_MODEL_LICENSE="MIT"
22
+ DEFAULT_MODEL_REPO="https://github.com/kreshuklab/go-nuclear"
23
+
24
+ # Run the Python script with default parameters
25
+ python export.py \
26
+ --channels $DEFAULT_CHANNELS \
27
+ --path_pretrained_model "$DEFAULT_PATH_PRETRAINED_MODEL" \
28
+ --path_readme "$DEFAULT_PATH_README" \
29
+ --list_path_cover_images $DEFAULT_LIST_PATH_COVER_IMAGES \
30
+ --model_version "$DEFAULT_MODEL_VERSION" \
31
+ --model_name "$DEFAULT_MODEL_NAME" \
32
+ --model_documentation "$DEFAULT_MODEL_DOCUMENTATION" \
33
+ --model_authors "$DEFAULT_MODEL_AUTHORS" \
34
+ --model_cite "$DEFAULT_MODEL_CITE" \
35
+ --model_tags $DEFAULT_MODEL_TAGS \
36
+ --model_license "$DEFAULT_MODEL_LICENSE" \
37
+ --model_repo "$DEFAULT_MODEL_REPO"
38
+ ```
39
+ """
40
+
41
+ import os
42
+ import sys
43
+ import json
44
+ import argparse
45
+ from pathlib import Path
46
+ from urllib.parse import urlparse
47
+
48
+ import torch
49
+ import numpy as np
50
+
51
+ from cellpose.io import imread
52
+ from cellpose.utils import download_url_to_file
53
+ from cellpose.transforms import pad_image_ND, normalize_img, convert_image
54
+ from cellpose.vit_sam import CPnetBioImageIO
55
+
56
+ from bioimageio.spec.model.v0_5 import (
57
+ ArchitectureFromFileDescr,
58
+ Author,
59
+ AxisId,
60
+ ChannelAxis,
61
+ CiteEntry,
62
+ Doi,
63
+ FileDescr,
64
+ Identifier,
65
+ InputTensorDescr,
66
+ IntervalOrRatioDataDescr,
67
+ LicenseId,
68
+ ModelDescr,
69
+ ModelId,
70
+ OrcidId,
71
+ OutputTensorDescr,
72
+ ParameterizedSize,
73
+ PytorchStateDictWeightsDescr,
74
+ SizeReference,
75
+ SpaceInputAxis,
76
+ SpaceOutputAxis,
77
+ TensorId,
78
+ TorchscriptWeightsDescr,
79
+ Version,
80
+ WeightsDescr,
81
+ )
82
+ # Define ARBITRARY_SIZE if it is not available in the module
83
+ try:
84
+ from bioimageio.spec.model.v0_5 import ARBITRARY_SIZE
85
+ except ImportError:
86
+ ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)
87
+
88
+ from bioimageio.spec.common import HttpUrl
89
+ from bioimageio.spec import save_bioimageio_package
90
+ from bioimageio.core import test_model
91
+
92
+ DEFAULT_CHANNELS = [2, 1]
93
+ DEFAULT_NORMALIZE_PARAMS = {
94
+ "axis": -1,
95
+ "lowhigh": None,
96
+ "percentile": None,
97
+ "normalize": True,
98
+ "norm3D": False,
99
+ "sharpen_radius": 0,
100
+ "smooth_radius": 0,
101
+ "tile_norm_blocksize": 0,
102
+ "tile_norm_smooth3D": 1,
103
+ "invert": False,
104
+ }
105
+ IMAGE_URL = "http://www.cellpose.org/static/data/rgb_3D.tif"
106
+
107
+
108
def download_and_normalize_image(path_dir_temp, channels=DEFAULT_CHANNELS):
    """
    Fetch the Cellpose sample image (IMAGE_URL) into path_dir_temp and prepare
    it for the network: convert channels, normalize intensities, move channels
    to axis 1, and pad the spatial dimensions.
    """
    filename = os.path.basename(urlparse(IMAGE_URL).path)
    path_image = path_dir_temp / filename
    # skip the download when the file is already cached locally
    if not path_image.exists():
        sys.stderr.write(f'Downloading: "{IMAGE_URL}" to {path_image}\n')
        download_url_to_file(IMAGE_URL, path_image)
    image = imread(path_image).astype(np.float32)
    image = convert_image(image, channels, channel_axis=1, z_axis=0, do_3D=False, nchan=2)
    image = normalize_img(image, **DEFAULT_NORMALIZE_PARAMS)
    # (z, y, x, c) -> (z, c, y, x)
    image = np.transpose(image, (0, 3, 1, 2))
    image, _, _ = pad_image_ND(image)
    return image
123
+
124
+
125
def load_bioimageio_cpnet_model(path_model_weight, nchan=2):
    """Build the BioImage.IO CPnet wrapper and load pretrained weights.

    Returns the model in eval mode together with the kwargs used to
    construct it (needed later for the architecture description).

    NOTE(review): *nchan* is accepted but never used here — confirm whether
    it should be forwarded to the network constructor.
    """
    kwargs = {
        "nout": 3,
    }
    model = CPnetBioImageIO(**kwargs)
    # Weights are loaded onto CPU so packaging works on machines without a GPU.
    state_dict = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True)
    model.load_state_dict(state_dict)
    model.eval()  # crucial for the prediction results
    return model, kwargs
134
+
135
+
136
def descr_gen_input(path_test_input, nchan=2):
    """Build the BioImage.IO input-tensor description for tensor ``raw``.

    Axes: z (any size), channel (c1..c<nchan>), and y/x constrained to
    multiples of 16 (minimum 16). *path_test_input* points at the saved
    test tensor (.npy).
    """
    axes = [
        SpaceInputAxis(id=AxisId("z"), size=ARBITRARY_SIZE),
        ChannelAxis(channel_names=[Identifier(f"c{i+1}") for i in range(nchan)]),
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=16, step=16)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=16, step=16)),
    ]
    return InputTensorDescr(
        id=TensorId("raw"),
        axes=axes,
        test_tensor=FileDescr(source=Path(path_test_input)),
        data=IntervalOrRatioDataDescr(type="float32"),
    )
152
+
153
+
154
def descr_gen_output_flow(path_test_output):
    """Build the description of the main ``flow`` output tensor.

    Spatial axes mirror the sizes of the ``raw`` input tensor; the channel
    axis carries the three flow components.
    """

    def _same_as_raw(axis_name):
        # Size tied to the corresponding axis of the input tensor.
        return SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId(axis_name))

    axes = [
        SpaceOutputAxis(id=AxisId("z"), size=_same_as_raw("z")),
        ChannelAxis(channel_names=[Identifier("flow1"), Identifier("flow2"), Identifier("flow3")]),
        SpaceOutputAxis(id=AxisId("y"), size=_same_as_raw("y")),
        SpaceOutputAxis(id=AxisId("x"), size=_same_as_raw("x")),
    ]
    return OutputTensorDescr(
        id=TensorId("flow"),
        axes=axes,
        test_tensor=FileDescr(source=Path(path_test_output)),
    )
168
+
169
+
170
def descr_gen_output_downsampled(path_dir_temp, nbase=None):
    """Describe the pyramid of downsampled feature tensors.

    One descriptor per entry of *nbase* (channel counts per level); level
    ``i`` has y/x scale ``2**i`` relative to the ``raw`` input and reads its
    test tensor from ``test_downsampled_<i>.npy`` in *path_dir_temp*.
    """
    if nbase is None:
        nbase = [32, 64, 128, 256]

    descriptors = []
    for level, n_feat in enumerate(nbase):
        axes = [
            SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
            ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(n_feat)]),
            SpaceOutputAxis(
                id=AxisId("y"),
                size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
                scale=2**level,
            ),
            SpaceOutputAxis(
                id=AxisId("x"),
                size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")),
                scale=2**level,
            ),
        ]
        descriptors.append(
            OutputTensorDescr(
                id=TensorId(f"downsampled_{level}"),
                axes=axes,
                test_tensor=FileDescr(source=Path(path_dir_temp / f"test_downsampled_{level}.npy")),
            )
        )
    return descriptors
203
+
204
+
205
def descr_gen_output_style(path_test_style, nchannel=256):
    """Describe the ``style`` output tensor (z axis + *nchannel* features)."""
    axes = [
        SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
        ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(nchannel)]),
    ]
    return OutputTensorDescr(
        id=TensorId("style"),
        axes=axes,
        test_tensor=FileDescr(source=Path(path_test_style)),
    )
217
+
218
+
219
def descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper=None):
    """Describe the model architecture for the pytorch_state_dict weights.

    Points BioImage.IO at the ``CPnetBioImageIO`` callable inside the wrapper
    module (defaults to ``resnet_torch.py`` next to this file).
    """
    wrapper = Path(__file__).parent / "resnet_torch.py" if path_cpnet_wrapper is None else path_cpnet_wrapper
    return ArchitectureFromFileDescr(
        callable=Identifier("CPnetBioImageIO"),
        source=Path(wrapper),
        kwargs=cpnet_kwargs,
    )
228
+
229
+
230
def descr_gen_documentation(path_doc, markdown_text):
    """Write the model documentation markdown to *path_doc*.

    Uses an explicit UTF-8 encoding so the result does not depend on the
    platform's locale default (README text routinely contains non-ASCII,
    e.g. accented author names), and a context manager so the file handle
    is always closed.
    """
    with open(path_doc, "w", encoding="utf-8") as f:
        f.write(markdown_text)
233
+
234
+
235
def package_to_bioimageio(
    path_pretrained_model,
    path_save_trace,
    path_readme,
    list_path_cover_images,
    descr_input,
    descr_output,
    descr_output_downsampled_tensors,
    descr_output_style_tensor,
    pytorch_version,
    pytorch_architecture,
    model_id,
    model_icon,
    model_version,
    model_name,
    model_documentation,
    model_authors,
    model_cite,
    model_tags,
    model_license,
    model_repo,
):
    """Package model description to BioImage.IO format.

    Assembles a ``ModelDescr`` from the previously generated tensor
    descriptions, metadata dicts (*model_authors*, *model_cite*), and the
    two weight flavours (state dict + traced TorchScript).
    """
    author_descrs = []
    for author in model_authors:
        author_descrs.append(
            Author(
                name=author["name"],
                affiliation=author["affiliation"],
                github_user=author["github_user"],
                orcid=OrcidId(author["orcid"]),
            )
        )

    cite_descrs = []
    for cite in model_cite:
        cite_descrs.append(CiteEntry(text=cite["text"], doi=Doi(cite["doi"]), url=cite["url"]))

    weights_descr = WeightsDescr(
        pytorch_state_dict=PytorchStateDictWeightsDescr(
            source=Path(path_pretrained_model),
            architecture=pytorch_architecture,
            pytorch_version=pytorch_version,
        ),
        torchscript=TorchscriptWeightsDescr(
            source=Path(path_save_trace),
            pytorch_version=pytorch_version,
            parent="pytorch_state_dict",  # these weights were converted from the pytorch_state_dict weights.
        ),
    )

    return ModelDescr(
        id=None if model_id is None else ModelId(model_id),
        id_emoji=model_icon,
        version=Version(model_version),
        name=model_name,
        description=model_documentation,
        authors=author_descrs,
        cite=cite_descrs,
        covers=[Path(img) for img in list_path_cover_images],
        license=LicenseId(model_license),
        tags=model_tags,
        documentation=Path(path_readme),
        git_repo=HttpUrl(model_repo),
        inputs=[descr_input],
        outputs=[descr_output, descr_output_style_tensor] + descr_output_downsampled_tensors,
        weights=weights_descr,
    )
296
+
297
+
298
def parse_args(argv=None):
    """Parse command-line arguments for BioImage.IO model packaging.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``, in
            which case argparse reads ``sys.argv[1:]`` as before; passing an
            explicit list makes the parser usable/testable programmatically.

    Returns:
        argparse.Namespace with the packaging configuration.
    """
    # fmt: off
    parser = argparse.ArgumentParser(description="BioImage.IO model packaging for Cellpose")
    parser.add_argument("--channels", nargs=2, default=[2, 1], type=int, help="Cyto-only = [2, 0], Cyto + Nuclei = [2, 1], Nuclei-only = [1, 0]")
    parser.add_argument("--path_pretrained_model", required=True, type=str, help="Path to pretrained model file, e.g., cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995")
    parser.add_argument("--path_readme", required=True, type=str, help="Path to README file")
    parser.add_argument("--list_path_cover_images", nargs='+', required=True, type=str, help="List of paths to cover images")
    parser.add_argument("--model_id", type=str, help="Model ID, provide if already exists", default=None)
    parser.add_argument("--model_icon", type=str, help="Model icon, provide if already exists", default=None)
    parser.add_argument("--model_version", required=True, type=str, help="Model version, new model should be 0.1.0")
    parser.add_argument("--model_name", required=True, type=str, help="Model name, e.g., My Cool Cellpose")
    parser.add_argument("--model_documentation", required=True, type=str, help="Model documentation, e.g., A cool Cellpose model trained for my cool dataset.")
    parser.add_argument("--model_authors", required=True, type=str, help="Model authors in JSON format, e.g., '[{\"name\": \"Qin Yu\", \"affiliation\": \"EMBL\", \"github_user\": \"qin-yu\", \"orcid\": \"0000-0002-4652-0795\"}]'")
    parser.add_argument("--model_cite", required=True, type=str, help="Model citation in JSON format, e.g., '[{\"text\": \"For more details of the model itself, see the manuscript\", \"doi\": \"10.1242/dev.202800\", \"url\": null}]'")
    parser.add_argument("--model_tags", nargs='+', required=True, type=str, help="Model tags, e.g., cellpose 3d 2d")
    parser.add_argument("--model_license", required=True, type=str, help="Model license, e.g., MIT")
    parser.add_argument("--model_repo", required=True, type=str, help="Model repository URL")
    # fmt: on
    return parser.parse_args(argv)
317
+
318
+
319
def main():
    """End-to-end packaging pipeline: CLI args -> test tensors -> descriptions
    -> validation -> BioImage.IO zip.

    Side effects: creates a per-model working directory under ../models/,
    downloads the sample image, writes test .npy tensors, saves a traced
    TorchScript model, and writes the final package zip.
    """
    args = parse_args()

    # Parse user-provided paths and arguments
    channels = args.channels
    # Author/citation metadata arrive as JSON strings on the CLI.
    model_cite = json.loads(args.model_cite)
    model_authors = json.loads(args.model_authors)

    path_readme = Path(args.path_readme)
    path_pretrained_model = Path(args.path_pretrained_model)
    list_path_cover_images = [Path(path_image) for path_image in args.list_path_cover_images]

    # Auto-generated paths: working dir is named after the weight file's stem.
    path_cpnet_wrapper = Path(__file__).resolve().parent / "resnet_torch.py"
    path_dir_temp = Path(__file__).resolve().parent.parent / "models" / path_pretrained_model.stem
    path_dir_temp.mkdir(parents=True, exist_ok=True)

    path_save_trace = path_dir_temp / "cp_traced.pt"
    path_test_input = path_dir_temp / "test_input.npy"
    path_test_output = path_dir_temp / "test_output.npy"
    path_test_style = path_dir_temp / "test_style.npy"
    path_bioimageio_package = path_dir_temp / "cellpose_model.zip"

    # Download test input image and persist it as the package's test input.
    img_np = download_and_normalize_image(path_dir_temp, channels=channels)
    np.save(path_test_input, img_np)
    img = torch.tensor(img_np).float()

    # Load model (eval mode, CPU weights).
    cpnet_biio, cpnet_kwargs = load_bioimageio_cpnet_model(path_pretrained_model)

    # Test model and save output: element 0 is the flow tensor, element 1 the
    # style tensor, the remainder the downsampled feature tensors.
    tuple_output_tensor = cpnet_biio(img)
    np.save(path_test_output, tuple_output_tensor[0].detach().numpy())
    np.save(path_test_style, tuple_output_tensor[1].detach().numpy())
    for i, t in enumerate(tuple_output_tensor[2:]):
        np.save(path_dir_temp / f"test_downsampled_{i}.npy", t.detach().numpy())

    # Save traced model (the torchscript weight flavour of the package).
    model_traced = torch.jit.trace(cpnet_biio, img)
    model_traced.save(path_save_trace)

    # Generate model description. nbase[1:] / nbase[-1] supply the per-level
    # channel counts for the downsampled and style outputs.
    descr_input = descr_gen_input(path_test_input)
    descr_output = descr_gen_output_flow(path_test_output)
    descr_output_downsampled_tensors = descr_gen_output_downsampled(path_dir_temp, nbase=cpnet_biio.nbase[1:])
    descr_output_style_tensor = descr_gen_output_style(path_test_style, cpnet_biio.nbase[-1])
    pytorch_version = Version(torch.__version__)
    pytorch_architecture = descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper)

    # Package model
    my_model_descr = package_to_bioimageio(
        path_pretrained_model,
        path_save_trace,
        path_readme,
        list_path_cover_images,
        descr_input,
        descr_output,
        descr_output_downsampled_tensors,
        descr_output_style_tensor,
        pytorch_version,
        pytorch_architecture,
        args.model_id,
        args.model_icon,
        args.model_version,
        args.model_name,
        args.model_documentation,
        model_authors,
        model_cite,
        args.model_tags,
        args.model_license,
        args.model_repo,
    )

    # Test model: validate both weight flavours before shipping.
    summary = test_model(my_model_descr, weight_format="pytorch_state_dict")
    summary.display()
    summary = test_model(my_model_descr, weight_format="torchscript")
    summary.display()

    # Save BioImage.IO package
    package_path = save_bioimageio_package(my_model_descr, output_path=Path(path_bioimageio_package))
    print("package path:", package_path)
402
+
403
+
404
# Script entry point: run the full packaging pipeline when invoked directly.
if __name__ == "__main__":
    main()