Upload 78 files
This view is limited to 50 files because it contains too many changes. See raw diff.
- .gitattributes +2 -0
- _utils/attn_utils.py +592 -0
- _utils/attn_utils_new.py +610 -0
- _utils/config.yaml +15 -0
- _utils/example_config.yaml +20 -0
- _utils/load_models.py +16 -0
- _utils/load_track_data.py +104 -0
- _utils/misc_helper.py +37 -0
- _utils/seg_eval.py +61 -0
- _utils/track_args.py +157 -0
- config.py +44 -0
- counting.py +337 -0
- example_imgs/1977_Well_F-5_Field_1.png +3 -0
- example_imgs/1977_Well_F-5_Field_1_seg.png +3 -0
- models/.DS_Store +0 -0
- models/enc_model/__init__.py +0 -0
- models/enc_model/backbone.py +64 -0
- models/enc_model/loca.py +232 -0
- models/enc_model/loca_args.py +44 -0
- models/enc_model/mlp.py +23 -0
- models/enc_model/ope.py +245 -0
- models/enc_model/positional_encoding.py +30 -0
- models/enc_model/regression_head.py +92 -0
- models/enc_model/transformer.py +94 -0
- models/enc_model/unet_parts.py +77 -0
- models/model.py +991 -0
- models/seg_post_model/cellpose/__init__.py +1 -0
- models/seg_post_model/cellpose/__main__.py +272 -0
- models/seg_post_model/cellpose/cli.py +240 -0
- models/seg_post_model/cellpose/core.py +322 -0
- models/seg_post_model/cellpose/denoise.py +1474 -0
- models/seg_post_model/cellpose/dynamics.py +691 -0
- models/seg_post_model/cellpose/export.py +405 -0
- models/seg_post_model/cellpose/gui/gui.py +2007 -0
- models/seg_post_model/cellpose/gui/gui3d.py +667 -0
- models/seg_post_model/cellpose/gui/guihelpwindowtext.html +143 -0
- models/seg_post_model/cellpose/gui/guiparts.py +793 -0
- models/seg_post_model/cellpose/gui/guitrainhelpwindowtext.html +25 -0
- models/seg_post_model/cellpose/gui/io.py +634 -0
- models/seg_post_model/cellpose/gui/make_train.py +107 -0
- models/seg_post_model/cellpose/gui/menus.py +145 -0
- models/seg_post_model/cellpose/io.py +816 -0
- models/seg_post_model/cellpose/metrics.py +205 -0
- models/seg_post_model/cellpose/models.py +524 -0
- models/seg_post_model/cellpose/plot.py +281 -0
- models/seg_post_model/cellpose/transforms.py +1261 -0
- models/seg_post_model/cellpose/utils.py +667 -0
- models/seg_post_model/cellpose/version.py +18 -0
- models/seg_post_model/cellpose/vit_sam.py +195 -0
- models/seg_post_model/cellpose/vit_sam_new.py +197 -0
.gitattributes
CHANGED
@@ -35,3 +35,5 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
 003_img.png filter=lfs diff=lfs merge=lfs -text
 1977_Well_F-5_Field_1.png filter=lfs diff=lfs merge=lfs -text
+example_imgs/1977_Well_F-5_Field_1_seg.png filter=lfs diff=lfs merge=lfs -text
+example_imgs/1977_Well_F-5_Field_1.png filter=lfs diff=lfs merge=lfs -text
_utils/attn_utils.py
ADDED
@@ -0,0 +1,592 @@
import abc

import cv2
import numpy as np
import torch
from IPython.display import display
from PIL import Image
from typing import Union, Tuple, List
from einops import rearrange, repeat
import math
from torch import nn, einsum
from inspect import isfunction
from diffusers.utils import logging

# These import paths moved between diffusers releases; fall back to the newer layout.
try:
    from diffusers.models.unet_2d_condition import UNet2DConditionOutput
except ImportError:
    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput

try:
    from diffusers.models.cross_attention import CrossAttention
except ImportError:
    from diffusers.models.attention_processor import Attention as CrossAttention

MAX_NUM_WORDS = 77
LOW_RESOURCE = False

class CountingCrossAttnProcessor1:

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, dim = hidden_states.shape
        h = attn_layer.heads
        q = attn_layer.to_q(hidden_states)
        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states
        k = attn_layer.to_k(context)
        v = attn_layer.to_v(context)
        # q = attn_layer.reshape_heads_to_batch_dim(q)
        # k = attn_layer.reshape_heads_to_batch_dim(k)
        # v = attn_layer.reshape_heads_to_batch_dim(v)
        # q = attn_layer.head_to_batch_dim(q)
        # k = attn_layer.head_to_batch_dim(k)
        # v = attn_layer.head_to_batch_dim(v)
        q = self.head_to_batch_dim(q, h)
        k = self.head_to_batch_dim(k, h)
        v = self.head_to_batch_dim(v, h)

        sim = torch.einsum("b i d, b j d -> b i j", q, k) * attn_layer.scale

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
            sim.masked_fill_(~attention_mask, max_neg_value)

        # attention, what we cannot get enough of
        attn_ = sim.softmax(dim=-1).clone()
        # softmax = nn.Softmax(dim=-1)
        # attn_ = softmax(sim)
        self.attnstore(attn_, is_cross, self.place_in_unet)
        out = torch.einsum("b i j, b j d -> b i d", attn_, v)
        # out = attn_layer.batch_to_head_dim(out)
        out = self.batch_to_head_dim(out, h)

        if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList:
            to_out = attn_layer.to_out[0]
        else:
            to_out = attn_layer.to_out

        out = to_out(out)
        return out

    def batch_to_head_dim(self, tensor, head_size):
        # head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, head_size, out_dim=3):
        # head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor

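The two reshape helpers above replicate diffusers' own head_to_batch_dim / batch_to_head_dim so the processor works across library versions. A minimal sketch of the round trip, with made-up tensor sizes:

import torch

proc = CountingCrossAttnProcessor1(attnstore=lambda *args: None, place_in_unet="down")
x = torch.randn(2, 64, 320)                    # (batch, tokens, channels) -- hypothetical sizes
xb = proc.head_to_batch_dim(x, head_size=8)    # folds the 8 heads into the batch axis: (16, 64, 40)
assert proc.batch_to_head_dim(xb, head_size=8).shape == x.shape  # exact inverse
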
+
def register_attention_control(model, controller):
|
| 97 |
+
|
| 98 |
+
attn_procs = {}
|
| 99 |
+
cross_att_count = 0
|
| 100 |
+
for name in model.unet.attn_processors.keys():
|
| 101 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
|
| 102 |
+
if name.startswith("mid_block"):
|
| 103 |
+
hidden_size = model.unet.config.block_out_channels[-1]
|
| 104 |
+
place_in_unet = "mid"
|
| 105 |
+
elif name.startswith("up_blocks"):
|
| 106 |
+
block_id = int(name[len("up_blocks.")])
|
| 107 |
+
hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
|
| 108 |
+
place_in_unet = "up"
|
| 109 |
+
elif name.startswith("down_blocks"):
|
| 110 |
+
block_id = int(name[len("down_blocks.")])
|
| 111 |
+
hidden_size = model.unet.config.block_out_channels[block_id]
|
| 112 |
+
place_in_unet = "down"
|
| 113 |
+
else:
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
cross_att_count += 1
|
| 117 |
+
# attn_procs[name] = AttendExciteCrossAttnProcessor(
|
| 118 |
+
# attnstore=controller, place_in_unet=place_in_unet
|
| 119 |
+
# )
|
| 120 |
+
attn_procs[name] = CountingCrossAttnProcessor1(
|
| 121 |
+
attnstore=controller, place_in_unet=place_in_unet
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
model.unet.set_attn_processor(attn_procs)
|
| 125 |
+
controller.num_att_layers = cross_att_count
|
| 126 |
+
|
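How the registration is typically wired up, sketched under the assumption that `model` is a diffusers pipeline exposing `.unet` (e.g. a Stable Diffusion 1.x checkpoint) and that the installed diffusers version matches the processor signature above; the controller then fills up during an ordinary sampling run:

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
controller = AttentionStore(max_size=32)        # defined later in this file
register_attention_control(pipe, controller)    # installs CountingCrossAttnProcessor1 on every attention layer
_ = pipe("a photo of three cells", num_inference_steps=20)
maps = controller.attention_store["up_cross"]   # cross-attention maps captured at the last step
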
def register_hier_output(model):
    self = model.unet
    logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

    def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None,
                attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None,
                down_block_additional_residuals=None, mid_block_additional_residual=None,
                encoder_attention_mask=None, return_dict=True):

        out_list = []

        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if requested by the config
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)  # 1, 320, 64, 64

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)

        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )

            # if i in [1, 4, 7]:
            out_list.append(sample)  # collect a feature map from every up block

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample), out_list

    self.forward = forward

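Once patched, the UNet returns the usual UNet2DConditionOutput plus a list of per-up-block feature maps (out_list), coarse to fine. A sketch of a direct call with hypothetical SD 1.x shapes, continuing the pipe from the registration example above:

import torch

register_hier_output(pipe)                      # patches pipe.unet.forward in place
latents = torch.randn(1, 4, 64, 64)             # hypothetical latent shape
text_emb = torch.randn(1, 77, 768)              # hypothetical CLIP text-embedding shape
with torch.no_grad():
    out, feats = pipe.unet.forward(latents, timestep=50, encoder_hidden_states=text_emb)
print([f.shape for f in feats])                 # one feature map per up block
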
class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            # self.forward(attn, is_cross, place_in_unet)
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0


class EmptyControl(AttentionControl):

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if len(self.global_store) == 0:
                    self.global_store = self.step_store
                else:
                    for key in self.global_store:
                        for i in range(len(self.global_store[key])):
                            self.global_store[key][i] += self.step_store[key][i].detach()
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = self.attention_store
        return average_attention

    def get_average_global_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in
                             self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, max_size=32, save_global_store=False):
        '''
        Initialize an empty AttentionStore.
        :param max_size: only attention maps up to max_size x max_size spatial resolution are kept
        :param save_global_store: accumulate maps across all diffusion steps for later averaging
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.max_size = max_size
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        self.curr_step_index = 0

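Note the half-batch slice in __call__: with LOW_RESOURCE = False the batch is assumed to stack the unconditional and conditional halves used by classifier-free guidance, so only the conditional half (attn[h // 2:]) is routed into the store. The step and layer counters also mean the store should be reset between prompts; a sketch reusing the pipeline wiring assumed above:

controller = AttentionStore(max_size=32, save_global_store=True)
register_attention_control(pipe, controller)
for prompt in ["two cells", "five cells"]:
    controller.reset()                                    # clear step/layer counters and stores
    _ = pipe(prompt, num_inference_steps=20)
    avg_maps = controller.get_average_global_attention()  # maps averaged over all denoising steps
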
def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    out = []
    attention_maps = attention_store.get_average_attention()
    num_pixels = res ** 2
    for location in from_where:
        for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
            if item.shape[1] == num_pixels:
                cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
                out.append(cross_maps)
    out = torch.cat(out, dim=0)
    out = out.sum(0) / out.shape[0]
    return out


def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    # `prompts` forwarded so aggregate_attention receives the prompt list its signature expects
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))


def show_self_attention_comp(prompts, attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0):
    # `prompts` added to the signature so the aggregate_attention call below is well-formed
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))


def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    h, w, c = image.shape
    offset = int(h * .2)
    img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
    font = cv2.FONT_HERSHEY_SIMPLEX
    # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
    img[:h] = image
    textsize = cv2.getTextSize(text, font, 1, 2)[0]
    text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
    cv2.putText(img, text, (text_x, text_y), font, 1, text_color, 2)
    return img


def view_images(images, num_rows=1, offset_ratio=0.02):
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)


def self_cross_attn(self_attn, cross_attn):
    res = self_attn.shape[0]
    assert res == cross_attn.shape[0]
    # cross attn [res, res] -> [res*res]
    cross_attn_ = cross_attn.reshape([res * res])
    # self_attn [res, res, res*res]
    self_cross_attn = cross_attn_ * self_attn
    self_cross_attn = self_cross_attn.mean(-1).unsqueeze(0).unsqueeze(0)
    return self_cross_attn
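A sketch of how these visualization helpers compose end to end, reusing the pipe and controller assumed in the earlier examples (this runs in a notebook, since view_images calls IPython's display): aggregate 16x16 cross-attention over the up and down blocks and render one panel per prompt token:

prompts = ["a photo of three cells"]
_ = pipe(prompts[0], num_inference_steps=20)
show_cross_attention(pipe.tokenizer, prompts, controller, res=16, from_where=["up", "down"])
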
_utils/attn_utils_new.py
ADDED
@@ -0,0 +1,610 @@
import abc

import cv2
import numpy as np
import torch
from IPython.display import display
from PIL import Image
from typing import Union, Tuple, List
from einops import rearrange, repeat
import math
from torch import nn, einsum
from inspect import isfunction
from diffusers.utils import logging

# These import paths moved between diffusers releases; fall back to the newer layout.
try:
    from diffusers.models.unet_2d_condition import UNet2DConditionOutput
except ImportError:
    from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
try:
    from diffusers.models.cross_attention import CrossAttention
except ImportError:
    from diffusers.models.attention_processor import Attention as CrossAttention
from typing import Any, Dict, List, Optional, Tuple, Union

MAX_NUM_WORDS = 77
LOW_RESOURCE = False


class CountingCrossAttnProcessor1:

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore
        self.place_in_unet = place_in_unet

    def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        batch_size, sequence_length, dim = hidden_states.shape
        h = attn_layer.heads
        q = attn_layer.to_q(hidden_states)
        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states
        k = attn_layer.to_k(context)
        v = attn_layer.to_v(context)
        # q = attn_layer.reshape_heads_to_batch_dim(q)
        # k = attn_layer.reshape_heads_to_batch_dim(k)
        # v = attn_layer.reshape_heads_to_batch_dim(v)
        # q = attn_layer.head_to_batch_dim(q)
        # k = attn_layer.head_to_batch_dim(k)
        # v = attn_layer.head_to_batch_dim(v)
        q = self.head_to_batch_dim(q, h)
        k = self.head_to_batch_dim(k, h)
        v = self.head_to_batch_dim(v, h)

        sim = torch.einsum("b i d, b j d -> b i j", q, k) * attn_layer.scale

        if attention_mask is not None:
            attention_mask = attention_mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
            sim.masked_fill_(~attention_mask, max_neg_value)

        # attention, what we cannot get enough of
        attn_ = sim.softmax(dim=-1).clone()
        # softmax = nn.Softmax(dim=-1)
        # attn_ = softmax(sim)
        self.attnstore(attn_, is_cross, self.place_in_unet)
        out = torch.einsum("b i j, b j d -> b i d", attn_, v)
        # out = attn_layer.batch_to_head_dim(out)
        out = self.batch_to_head_dim(out, h)

        if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList:
            to_out = attn_layer.to_out[0]
        else:
            to_out = attn_layer.to_out

        out = to_out(out)
        return out

    def batch_to_head_dim(self, tensor, head_size):
        # head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, head_size, out_dim=3):
        # head_size = self.heads
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor


def register_attention_control(model, controller):

    attn_procs = {}
    cross_att_count = 0
    for name in model.unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = model.unet.config.block_out_channels[-1]
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = model.unet.config.block_out_channels[block_id]
            place_in_unet = "down"
        else:
            continue

        cross_att_count += 1
        # attn_procs[name] = AttendExciteCrossAttnProcessor(
        #     attnstore=controller, place_in_unet=place_in_unet
        # )
        attn_procs[name] = CountingCrossAttnProcessor1(
            attnstore=controller, place_in_unet=place_in_unet
        )

    model.unet.set_attn_processor(attn_procs)
    controller.num_att_layers = cross_att_count


def register_hier_output(model):
    self = model.unet
    logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

    def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None,
                attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None,
                down_block_additional_residuals=None, mid_block_additional_residual=None,
                encoder_attention_mask=None, return_dict=True):

        out_list = []

        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        if attention_mask is not None:
            # assume that mask is expressed as:
            #   (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            #   (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        # 0. center input if requested by the config
        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)
        t_emb = t_emb.to(dtype=sample.dtype)
        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)

        # 2. pre-process
        sample = self.conv_in(sample)  # 1, 320, 64, 64

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)

        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets):]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )

            out_list.append(sample)  # collect a feature map from every up block

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            return (sample,)

        return UNet2DConditionOutput(sample=sample), out_list

    self.forward = forward


class AttentionControl(abc.ABC):

    def step_callback(self, x_t):
        return x_t

    def between_steps(self):
        return

    @property
    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            # self.forward(attn, is_cross, place_in_unet)
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                h = attn.shape[0]
                attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1
        self.cur_att_layer = 0


class EmptyControl(AttentionControl):

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn


class AttentionStore(AttentionControl):

    @staticmethod
    def get_empty_store():
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if len(self.global_store) == 0:
                    self.global_store = self.step_store
                else:
                    for key in self.global_store:
                        for i in range(len(self.global_store[key])):
                            self.global_store[key][i] += self.step_store[key][i].detach()
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        average_attention = self.attention_store
        return average_attention

    def get_average_global_attention(self):
        average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in
                             self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, max_size=32, save_global_store=False):
        '''
        Initialize an empty AttentionStore.
        :param max_size: only attention maps up to max_size x max_size spatial resolution are kept
        :param save_global_store: accumulate maps across all diffusion steps for later averaging
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.max_size = max_size
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        self.curr_step_index = 0
+
self.curr_step_index = 0
|
| 501 |
+
|
| 502 |
+
def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
|
| 503 |
+
out = []
|
| 504 |
+
attention_maps = attention_store.get_average_attention()
|
| 505 |
+
num_pixels = res ** 2
|
| 506 |
+
for location in from_where:
|
| 507 |
+
for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
|
| 508 |
+
if item.shape[1] == num_pixels:
|
| 509 |
+
cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
|
| 510 |
+
out.append(cross_maps)
|
| 511 |
+
out = torch.cat(out, dim=0)
|
| 512 |
+
out = out.sum(0) / out.shape[0]
|
| 513 |
+
return out
|
| 514 |
+
|
| 515 |
+
def aggregate_attention1(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
|
| 516 |
+
out = []
|
| 517 |
+
attention_maps = attention_store.get_average_attention()
|
| 518 |
+
num_pixels = res ** 2
|
| 519 |
+
for location in from_where:
|
| 520 |
+
for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
|
| 521 |
+
if item.shape[1] == num_pixels:
|
| 522 |
+
cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
|
| 523 |
+
out.append(cross_maps)
|
| 524 |
+
# out = torch.cat(out, dim=0)
|
| 525 |
+
# out = out.sum(0) / out.shape[0]
|
| 526 |
+
out = out[1]
|
| 527 |
+
out = out.sum(0) / out.shape[0]
|
| 528 |
+
return out
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
|
| 532 |
+
tokens = tokenizer.encode(prompts[select])
|
| 533 |
+
decoder = tokenizer.decode
|
| 534 |
+
attention_maps = aggregate_attention(attention_store, res, from_where, True, select)
|
| 535 |
+
images = []
|
| 536 |
+
for i in range(len(tokens)):
|
| 537 |
+
image = attention_maps[:, :, i]
|
| 538 |
+
image = 255 * image / image.max()
|
| 539 |
+
image = image.unsqueeze(-1).expand(*image.shape, 3)
|
| 540 |
+
image = image.numpy().astype(np.uint8)
|
| 541 |
+
image = np.array(Image.fromarray(image).resize((256, 256)))
|
| 542 |
+
image = text_under_image(image, decoder(int(tokens[i])))
|
| 543 |
+
images.append(image)
|
| 544 |
+
view_images(np.stack(images, axis=0))
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
|
| 548 |
+
max_com=10, select: int = 0):
|
| 549 |
+
attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
|
| 550 |
+
u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
|
| 551 |
+
images = []
|
| 552 |
+
for i in range(max_com):
|
| 553 |
+
image = vh[i].reshape(res, res)
|
| 554 |
+
image = image - image.min()
|
| 555 |
+
image = 255 * image / image.max()
|
| 556 |
+
image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
|
| 557 |
+
image = Image.fromarray(image).resize((256, 256))
|
| 558 |
+
image = np.array(image)
|
| 559 |
+
images.append(image)
|
| 560 |
+
view_images(np.concatenate(images, axis=1))
|
| 561 |
+
|
| 562 |
+
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
|
| 563 |
+
h, w, c = image.shape
|
| 564 |
+
offset = int(h * .2)
|
| 565 |
+
img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
|
| 566 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 567 |
+
# font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
|
| 568 |
+
img[:h] = image
|
| 569 |
+
textsize = cv2.getTextSize(text, font, 1, 2)[0]
|
| 570 |
+
text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
|
| 571 |
+
cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
|
| 572 |
+
return img
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def view_images(images, num_rows=1, offset_ratio=0.02):
|
| 576 |
+
if type(images) is list:
|
| 577 |
+
num_empty = len(images) % num_rows
|
| 578 |
+
elif images.ndim == 4:
|
| 579 |
+
num_empty = images.shape[0] % num_rows
|
| 580 |
+
else:
|
| 581 |
+
images = [images]
|
| 582 |
+
num_empty = 0
|
| 583 |
+
|
| 584 |
+
empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
|
| 585 |
+
images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
|
| 586 |
+
num_items = len(images)
|
| 587 |
+
|
| 588 |
+
h, w, c = images[0].shape
|
| 589 |
+
offset = int(h * offset_ratio)
|
| 590 |
+
num_cols = num_items // num_rows
|
| 591 |
+
image_ = np.ones((h * num_rows + offset * (num_rows - 1),
|
| 592 |
+
w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
|
| 593 |
+
for i in range(num_rows):
|
| 594 |
+
for j in range(num_cols):
|
| 595 |
+
image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
|
| 596 |
+
i * num_cols + j]
|
| 597 |
+
|
| 598 |
+
pil_img = Image.fromarray(image_)
|
| 599 |
+
display(pil_img)
|
| 600 |
+
|
| 601 |
+
def self_cross_attn(self_attn, cross_attn):
|
| 602 |
+
cross_attn = cross_attn.squeeze()
|
| 603 |
+
res = self_attn.shape[0]
|
| 604 |
+
assert res == cross_attn.shape[-1]
|
| 605 |
+
# cross attn [res, res] -> [res*res]
|
| 606 |
+
cross_attn_ = cross_attn.reshape([res*res])
|
| 607 |
+
# self_attn [res, res, res*res]
|
| 608 |
+
self_cross_attn = cross_attn_ * self_attn
|
| 609 |
+
self_cross_attn = self_cross_attn.mean(-1).unsqueeze(0).unsqueeze(0)
|
| 610 |
+
return self_cross_attn
|
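Taken together, the pieces above can be wired up as follows — a minimal sketch, assuming the register_attention_control helper defined earlier in this file and an already-loaded diffusers StableDiffusionPipeline named pipe; one short denoising call is enough to populate the store:

# illustrative sketch, not part of the diff
controller = AttentionStore(max_size=64)
register_attention_control(pipe, controller)  # hook every attention layer
_ = pipe("a photo of repetitive objects", num_inference_steps=1)
attn = aggregate_attention(
    prompts=["a photo of repetitive objects"],
    attention_store=controller,
    res=16, from_where=("up", "down"), is_cross=True, select=0,
)  # -> [16, 16, 77]: one spatial heatmap per text token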
_utils/config.yaml
ADDED
@@ -0,0 +1,15 @@
attn_dist_mode: v0
attn_positional_bias: rope
attn_positional_bias_n_spatial: 16
causal_norm: quiet_softmax
coord_dim: 2
d_model: 320
dropout: 0.0
feat_dim: 7
feat_embed_per_dim: 8
nhead: 4
num_decoder_layers: 6
num_encoder_layers: 6
pos_embed_per_dim: 32
spatial_pos_cutoff: 256
window: 4
_utils/example_config.yaml
ADDED
@@ -0,0 +1,20 @@
batch_size: 1
crop_size:
- 256
- 256
detection_folders:
- TRA
dropout: 0.01
example_images: False # Slow
input_train:
- data/ctc/Fluo-N2DL-HeLa/01
input_val:
- data/ctc/Fluo-N2DL-HeLa/02
max_tokens: 2048
name: example
ndim: 2
num_decoder_layers: 5
num_encoder_layers: 5
outdir: runs
distributed: False
window: 4
_utils/load_models.py
ADDED
@@ -0,0 +1,16 @@
from config import RunConfig
import torch
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
import torch.nn as nn

def load_stable_diffusion_model(config: RunConfig):
    device = torch.device('cpu')

    if config.sd_2_1:
        stable_diffusion_version = "stabilityai/stable-diffusion-2-1-base"
    else:
        stable_diffusion_version = "CompVis/stable-diffusion-v1-4"
    # stable = StableCountingPipeline.from_pretrained(stable_diffusion_version).to(device)
    stable = StableDiffusionPipeline.from_pretrained(stable_diffusion_version).to(device)
    return stable
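A minimal usage sketch (assuming the default RunConfig from config.py, which selects CompVis/stable-diffusion-v1-4; the first call downloads the weights):

# illustrative sketch, not part of the diff
from config import RunConfig
from _utils.load_models import load_stable_diffusion_model

stable = load_stable_diffusion_model(RunConfig())
print(type(stable).__name__)  # StableDiffusionPipeline, placed on CPU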
_utils/load_track_data.py
ADDED
@@ -0,0 +1,104 @@
import os
from glob import glob
from pathlib import Path
from natsort import natsorted
from PIL import Image
import numpy as np
import tifffile
import skimage.io as io
import torchvision.transforms as T
import cv2
from tqdm import tqdm
from models.tra_post_model.trackastra.utils import normalize_01, normalize
IMG_SIZE = 512

def _load_tiffs(folder: Path, dtype=None):
    """Load a sequence of tiff files from a folder into a 3D numpy array."""
    images = glob(str(folder / "*.tif"))
    test_data = tifffile.imread(images[0])
    if len(test_data.shape) == 3:
        turn_gray = True
    else:
        turn_gray = False
    end_frame = len(images)
    if not turn_gray:
        x = np.stack([
            tifffile.imread(f).astype(dtype)
            for f in tqdm(
                sorted(folder.glob("*.tif"))[0 : end_frame : 1],
                leave=False,
                desc=f"Loading [0:{end_frame}]",
            )
        ])
    else:
        x = []
        for f in tqdm(
            sorted(folder.glob("*.tif"))[0 : end_frame : 1],
            leave=False,
            desc=f"Loading [0:{end_frame}]",
        ):
            img = tifffile.imread(f).astype(dtype)
            if img.ndim == 3:
                if img.shape[-1] > 3:
                    img = img[..., :3]
                img = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2])
            x.append(img)
        x = np.stack(x)
    return x


def load_track_images(file_dir):

    # suffix_ = [".png", ".tif", ".tiff", ".jpg"]
    assert len(glob(file_dir + "/*.tif")) > 0, f"No tif images found in {file_dir}"
    images = natsorted(glob(file_dir + "/*.tif"))
    imgs = []
    imgs_raw = []
    images_stable = []
    # load images for seg and track
    for img_path in tqdm(images, desc="Loading images"):
        img = tifffile.imread(img_path)
        img_raw = io.imread(img_path)

        if img.dtype == 'uint16':
            img = ((img - img.min()) / (img.max() - img.min() + 1e-6) * 255).astype(np.uint8)
            img = np.stack([img] * 3, axis=-1)
            w, h = img.shape[1], img.shape[0]
        else:
            img = Image.open(img_path).convert("RGB")
            w, h = img.size

        img = T.Compose([
            T.ToTensor(),
            T.Resize((IMG_SIZE, IMG_SIZE)),
        ])(img)

        image_stable = img - 0.5
        img = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)

        imgs.append(img)
        imgs_raw.append(img_raw)
        images_stable.append(image_stable)

    height = h
    width = w
    imgs = np.stack(imgs, axis=0)
    imgs_raw = np.stack(imgs_raw, axis=0)
    images_stable = np.stack(images_stable, axis=0)

    # track data
    imgs_ = _load_tiffs(Path(file_dir), dtype=np.float32)
    imgs_01 = np.stack([
        normalize_01(_x) for _x in tqdm(imgs_, desc="Normalizing", leave=False)
    ])
    imgs_ = np.stack([
        normalize(_x) for _x in tqdm(imgs_, desc="Normalizing", leave=False)
    ])

    return imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width

if __name__ == "__main__":
    file_dir = "data/2D+Time/DIC-C2DH-HeLa/train/DIC-C2DH-HeLa/02"
    imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width = load_track_images(file_dir)
    print(imgs.shape, imgs_raw.shape, images_stable.shape, imgs_.shape, imgs_01.shape, height, width)
_utils/misc_helper.py
ADDED
@@ -0,0 +1,37 @@
import logging
import os
import random
import shutil
from collections.abc import Mapping
from datetime import datetime

import numpy as np
import torch
import torch.distributed as dist


def basicConfig(*args, **kwargs):
    return


# To prevent duplicate logs, we mask this basicConfig setting
logging.basicConfig = basicConfig


def create_logger(name, log_file, level=logging.INFO):
    log = logging.getLogger(name)
    formatter = logging.Formatter(
        "[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s"
    )
    fh = logging.FileHandler(log_file)
    fh.setFormatter(formatter)
    sh = logging.StreamHandler()
    sh.setFormatter(formatter)
    log.setLevel(level)
    log.addHandler(fh)
    log.addHandler(sh)
    return log

def get_current_time():
    current_time = datetime.now().strftime("%Y%m%d_%H%M%S")
    return current_time
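Usage sketch for the two helpers (assumes the log file path is writable; create_logger attaches both a file and a console handler with the same format):

# illustrative sketch, not part of the diff
from _utils.misc_helper import create_logger, get_current_time

log = create_logger("train", f"train_{get_current_time()}.log")
log.info("run started")  # written to the file and echoed to the console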
_utils/seg_eval.py
ADDED
@@ -0,0 +1,61 @@
import torch
from torch import Tensor


def iou_torch(inst1, inst2):
    inter = torch.logical_and(inst1, inst2).sum().float()
    union = torch.logical_or(inst1, inst2).sum().float()
    if union == 0:
        return torch.tensor(float('nan'))
    return inter / union

def get_instances_torch(mask):
    # return all non-background instance masks (boolean)
    ids = torch.unique(mask)
    return [(mask == i) for i in ids if i != 0]

def compute_instance_miou(pred_mask, gt_mask):
    # pred_mask and gt_mask are integer torch.Tensors of shape [H, W]
    pred_instances = get_instances_torch(pred_mask)
    gt_instances = get_instances_torch(gt_mask)

    ious = []
    for gt in gt_instances:
        best_iou = torch.tensor(0.0).to(pred_mask.device)
        for pred in pred_instances:
            i = iou_torch(pred, gt)
            if i > best_iou:
                best_iou = i
        ious.append(best_iou)

    # handle the empty case
    if len(ious) == 0:
        return torch.tensor(float('nan'))
    return torch.nanmean(torch.stack(ious))


def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all batches, or for a single mask
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    sum_dim = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

    inter = 2 * (input * target).sum(dim=sum_dim)
    sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

    dice = (inter + epsilon) / (sets_sum + epsilon)
    return dice.mean()


def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    # Average of Dice coefficient for all classes
    return dice_coeff(input.flatten(0, 1), target.flatten(0, 1), reduce_batch_first, epsilon)


def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    # Dice loss (objective to minimize) between 0 and 1
    fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - fn(input, target, reduce_batch_first=True)
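A small self-check of the metrics on synthetic masks. Two 4x4 squares offset by one pixel overlap in a 3x3 region, so the expected instance IoU is 9 / (16 + 16 - 9) = 9/23 ≈ 0.391:

# illustrative sketch, not part of the diff
import torch
from _utils.seg_eval import compute_instance_miou, dice_loss

pred = torch.zeros(8, 8, dtype=torch.long)
pred[2:6, 2:6] = 1                      # one predicted instance
gt = torch.zeros(8, 8, dtype=torch.long)
gt[3:7, 3:7] = 1                        # one ground-truth instance, shifted by 1 px
print(compute_instance_miou(pred, gt))  # tensor(0.3913) = 9/23

probs = torch.rand(2, 8, 8)             # dice_loss with reduce_batch_first expects [N, H, W]
target = (torch.rand(2, 8, 8) > 0.5).float()
print(dice_loss(probs, target))         # scalar loss in [0, 1]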
_utils/track_args.py
ADDED
@@ -0,0 +1,157 @@
import configargparse


def parse_train_args():
    parser = configargparse.ArgumentParser(
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        allow_abbrev=False,
    )
    parser.add_argument(
        "-c",
        "--config",
        default="_utils/example_config.yaml",
        is_config_file=True,
        help="config file path",
    )
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("-o", "--outdir", type=str, default="runs")
    parser.add_argument("--name", type=str, help="Name to append to timestamp")
    parser.add_argument("--timestamp", type=bool, default=True)
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="",
        help="load this model at start (e.g. to continue training)",
    )
    parser.add_argument(
        "--ndim", type=int, default=2, help="number of spatial dimensions"
    )
    parser.add_argument("-d", "--d_model", type=int, default=256)
    parser.add_argument("-w", "--window", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--warmup_epochs", type=int, default=10)
    parser.add_argument(
        "--detection_folders",
        type=str,
        nargs="+",
        default=["TRA"],
        help=(
            "Subfolders to search for detections. Defaults to `TRA`, which corresponds"
            " to using only the GT."
        ),
    )
    parser.add_argument("--downscale_temporal", type=int, default=1)
    parser.add_argument("--downscale_spatial", type=int, default=1)
    parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
    parser.add_argument("--from_subfolder", action="store_true")
    # parser.add_argument("--train_samples", type=int, default=50000)
    parser.add_argument("--num_encoder_layers", type=int, default=6)
    parser.add_argument("--num_decoder_layers", type=int, default=6)
    parser.add_argument("--pos_embed_per_dim", type=int, default=32)
    parser.add_argument("--feat_embed_per_dim", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.00)
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_tokens", type=int, default=None)
    parser.add_argument("--delta_cutoff", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument(
        "--attn_positional_bias",
        type=str,
        choices=["rope", "bias", "none"],
        default="rope",
    )
    parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
    parser.add_argument("--attn_dist_mode", default="v0")
    parser.add_argument("--mixedp", type=bool, default=True)
    parser.add_argument("--dry", action="store_true")
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--features",
        type=str,
        choices=[
            "none",
            "regionprops",
            "regionprops2",
            "patch",
            "patch_regionprops",
            "wrfeat",
        ],
        default="wrfeat",
    )
    parser.add_argument(
        "--causal_norm",
        type=str,
        choices=["none", "linear", "softmax", "quiet_softmax"],
        default="quiet_softmax",
    )
    parser.add_argument("--div_upweight", type=float, default=2)

    parser.add_argument("--augment", type=int, default=3)
    parser.add_argument("--tracking_frequency", type=int, default=-1)

    parser.add_argument("--sanity_dist", action="store_true")
    parser.add_argument("--preallocate", type=bool, default=False)
    parser.add_argument("--only_prechecks", action="store_true")
    parser.add_argument(
        "--compress", type=bool, default=True, help="compress dataset"
    )


    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb", "none"],
    )
    parser.add_argument("--wandb_project", type=str, default="trackastra")
    parser.add_argument(
        "--crop_size",
        type=int,
        # required=True,
        nargs="+",
        default=None,
        help="random crop size for augmentation",
    )
    parser.add_argument(
        "--weight_by_ndivs",
        type=bool,
        default=True,
        help="Oversample windows that contain divisions",
    )
    parser.add_argument(
        "--weight_by_dataset",
        type=bool,
        default=False,
        help=(
            "Inversely weight datasets by number of samples (to counter dataset size"
            " imbalance)"
        ),
    )

    args, unknown_args = parser.parse_known_args()

    # # Hack to allow for --input_test
    # allowed_unknown = ["input_test"]
    # if not set(a.split("=")[0].strip("-") for a in unknown_args).issubset(
    #     set(allowed_unknown)
    # ):
    #     raise ValueError(f"Unknown args: {unknown_args}")

    # pprint(vars(args))

    # for backward compatibility
    # if args.attn_positional_bias == "True":
    #     args.attn_positional_bias = "bias"
    # elif args.attn_positional_bias == "False":
    #     args.attn_positional_bias = False

    # if args.train_samples == 0:
    #     raise NotImplementedError(
    #         "--train_samples must be > 0, full dataset pass not supported."
    #     )

    return args
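Usage sketch (assumes running from the repo root so the default -c path resolves): configargparse merges three layers, with CLI flags overriding YAML values, which in turn override the argparse defaults — so _utils/example_config.yaml sets window to 4 even though the flag's default is 10:

# illustrative sketch, not part of the diff
from _utils.track_args import parse_train_args

args = parse_train_args()       # loads _utils/example_config.yaml by default
print(args.name, args.window)   # example 4
print(args.crop_size)           # [256, 256]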
config.py
ADDED
@@ -0,0 +1,44 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List


@dataclass
class RunConfig:
    # Guiding text prompt
    prompt: str = "<task-prompt>"
    # Whether to use Stable Diffusion v2.1
    sd_2_1: bool = False
    # Which token indices to alter with attend-and-excite
    token_indices: List[int] = field(default_factory=lambda: [2, 5])
    # Which random seeds to use when generating
    seeds: List[int] = field(default_factory=lambda: [42])
    # Path to save all outputs to
    output_path: Path = Path('./outputs')
    # Number of denoising steps
    n_inference_steps: int = 50
    # Text guidance scale
    guidance_scale: float = 7.5
    # Number of denoising steps to apply attend-and-excite
    max_iter_to_alter: int = 25
    # Resolution of UNet to compute attention maps over
    attention_res: int = 16
    # Whether to run standard SD or attend-and-excite
    run_standard_sd: bool = False
    # Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in
    thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8})
    # Scale factor for updating the denoised latent z_t
    scale_factor: int = 20
    # Start and end values used for scaling the scale factor - decays linearly with the denoising timestep
    scale_range: tuple = field(default_factory=lambda: (1.0, 0.5))
    # Whether to apply the Gaussian smoothing before computing the maximum attention value for each subject token
    smooth_attentions: bool = True
    # Standard deviation for the Gaussian smoothing
    sigma: float = 0.5
    # Kernel size for the Gaussian smoothing
    kernel_size: int = 3
    # Whether to save cross attention maps for the final results
    save_cross_attention_maps: bool = False

    def __post_init__(self):
        self.output_path.mkdir(exist_ok=True, parents=True)
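RunConfig is a plain dataclass, so a driver script can override fields directly (a sketch; note that __post_init__ creates the output directory as a side effect of construction):

# illustrative sketch, not part of the diff
from config import RunConfig

cfg = RunConfig(sd_2_1=True, n_inference_steps=30)
print(cfg.prompt, cfg.output_path)  # <task-prompt> outputs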
counting.py
ADDED
@@ -0,0 +1,337 @@
# stable diffusion x loca
import os
# os.system("source /etc/network_turbo")
os.environ["CUDA_VISIBLE_DEVICES"] = "2"
import pprint
from typing import Any, List, Optional
import argparse
import pyrallis
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
from PIL import Image
import numpy as np
from config import RunConfig
from _utils import attn_utils_new as attn_utils
from _utils.attn_utils import AttentionStore
from _utils.misc_helper import *
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
import warnings
from pytorch_lightning.callbacks import ModelCheckpoint
warnings.filterwarnings("ignore", category=UserWarning)
import pytorch_lightning as pl
from _utils.load_models import load_stable_diffusion_model
from models.model import Counting_with_SD_features_loca as Counting
from pytorch_lightning.loggers import WandbLogger
from models.enc_model.loca_args import get_argparser as loca_get_argparser
from models.enc_model.loca import build_model as build_loca_model
import time
import torchvision.transforms as T
import skimage.io as io
from _utils.dummy_box_gen import gen_dummy_boxes

SCALE = 1


class CountingModule(pl.LightningModule):
    def __init__(self, use_box=True):
        super().__init__()
        self.use_box = use_box
        self.config = RunConfig()  # config for stable diffusion
        self.initialize_model()


    def initialize_model(self):

        # load loca model
        loca_args = loca_get_argparser().parse_args()
        self.loca_model = build_loca_model(loca_args)
        # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
        # weights = {k.replace("module","") : v for k, v in weights.items()}
        # self.loca_model.load_state_dict(weights, strict=False)
        # del weights

        self.counting_adapter = Counting(scale_factor=SCALE)
        # if os.path.isfile(self.args.adapter_weight):
        #     adapter_weight = torch.load(self.args.adapter_weight, map_location=torch.device('cpu'))
        #     self.counting_adapter.load_state_dict(adapter_weight, strict=False)

        ### load stable diffusion and its controller
        self.stable = load_stable_diffusion_model(config=self.config)
        self.noise_scheduler = self.stable.scheduler
        self.controller = AttentionStore(max_size=64)
        attn_utils.register_attention_control(self.stable, self.controller)
        attn_utils.register_hier_output(self.stable)

        ##### initialize token_emb #####
        placeholder_token = "<task-prompt>"
        self.task_token = "repetitive objects"
        # Add the placeholder token in tokenizer
        num_added_tokens = self.stable.tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        if os.path.isfile("pretrained/task_embed.pth"):
            task_embed_from_pretrain = torch.load("pretrained/task_embed.pth")
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)
            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))

            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            token_embeds[placeholder_token_id] = task_embed_from_pretrain
        else:
            initializer_token = "count"
            token_ids = self.stable.tokenizer.encode(initializer_token, add_special_tokens=False)
            # Check if initializer_token is a single token or a sequence of tokens
            if len(token_ids) > 1:
                raise ValueError("The initializer token must be a single token.")

            initializer_token_id = token_ids[0]
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)

            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))

            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

        # others
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id


    def move_to_device(self, device):
        self.stable.to(device)
        if self.loca_model is not None and self.counting_adapter is not None:
            self.loca_model.to(device)
            self.counting_adapter.to(device)
        self.to(device)

    def forward(self, data_path, box=None):
        filename = data_path.split("/")[-1]
        img = Image.open(data_path).convert("RGB")
        width, height = img.size
        input_image = T.Compose([T.ToTensor(), T.Resize((512, 512))])(img)
        input_image_stable = input_image - 0.5
        input_image = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_image)
        if box is not None:
            boxes = torch.tensor(box) / torch.tensor([width, height, width, height]) * 512  # xyxy, normalized
            assert self.use_box == True
        else:
            boxes = torch.tensor([[100, 100, 130, 130], [200, 200, 250, 250]], dtype=torch.float32)  # dummy box
            assert self.use_box == False

        # move to device
        input_image = input_image.unsqueeze(0).to(self.device)
        boxes = boxes.unsqueeze(0).to(self.device)
        input_image_stable = input_image_stable.unsqueeze(0).to(self.device)

        latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
        latents = latents * 0.18215
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        timesteps = torch.tensor([20], device=latents.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        input_ids_ = self.stable.tokenizer(
            self.placeholder_token + " repetitive objects",
            # "object",
            padding="max_length",
            truncation=True,
            max_length=self.stable.tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = input_ids_["input_ids"].to(self.device)
        attention_mask = input_ids_["attention_mask"].to(self.device)
        encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]

        input_image = input_image.to(self.device)
        boxes = boxes.to(self.device)

        task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
        if self.use_box:
            loca_out = self.loca_model.forward_before_reg(input_image, boxes)
            loca_feature_bf_regression = loca_out["feature_bf_regression"]
            adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes)  # shape [1, 768]
            if task_loc_idx.shape[0] == 0:
                encoder_hidden_states[0, 2, :] = adapted_emb.squeeze()  # place right after the task-prompt token
            else:
                encoder_hidden_states[0, task_loc_idx[0, 1] + 1, :] = adapted_emb.squeeze()  # place right after the task-prompt token

        # Predict the noise residual
        noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
        noise_pred = noise_pred.sample
        attention_store = self.controller.attention_store

        attention_maps = []
        exemplar_attention_maps = []
        exemplar_attention_maps1 = []
        exemplar_attention_maps2 = []
        exemplar_attention_maps3 = []

        cross_self_task_attn_maps = []
        cross_self_exe_attn_maps = []

        # only use 64x64 self-attention
        self_attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt],  # should this be changed?
            attention_store=self.controller,
            res=64,
            from_where=("up", "down"),
            is_cross=False,
            select=0
        )
        self_attn_aggregate32 = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt],  # should this be changed?
            attention_store=self.controller,
            res=32,
            from_where=("up", "down"),
            is_cross=False,
            select=0
        )
        self_attn_aggregate16 = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt],  # should this be changed?
            attention_store=self.controller,
            res=16,
            from_where=("up", "down"),
            is_cross=False,
            select=0
        )

        # cross attention
        for res in [32, 16]:
            attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 77]
                prompts=[self.config.prompt],  # should this be changed?
                attention_store=self.controller,
                res=res,
                from_where=("up", "down"),
                is_cross=True,
                select=0
            )

            task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0)  # [1, 1, res, res]
            attention_maps.append(task_attn_)
            if self.use_box:
                exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)  # take the exemplar token's attention
                exemplar_attention_maps.append(exemplar_attns)
            else:
                exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)
                exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0)
                exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0)
                exemplar_attention_maps1.append(exemplar_attns1)
                exemplar_attention_maps2.append(exemplar_attns2)
                exemplar_attention_maps3.append(exemplar_attns3)


        scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
        attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))])
        task_attn_64 = torch.mean(attns, dim=0, keepdim=True)
        cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
        cross_self_task_attn_maps.append(cross_self_task_attn)

        if self.use_box:
            scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))])
            exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True)

            cross_self_exe_attn = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64)
            cross_self_exe_attn_maps.append(cross_self_exe_attn)
        else:
            scale_factors = [(64 // exemplar_attention_maps1[i].shape[-1]) for i in range(len(exemplar_attention_maps1))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps1[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps1))])
            exemplar_attn_64_1 = torch.mean(attns, dim=0, keepdim=True)

            scale_factors = [(64 // exemplar_attention_maps2[i].shape[-1]) for i in range(len(exemplar_attention_maps2))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps2[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps2))])
            exemplar_attn_64_2 = torch.mean(attns, dim=0, keepdim=True)

            scale_factors = [(64 // exemplar_attention_maps3[i].shape[-1]) for i in range(len(exemplar_attention_maps3))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps3[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps3))])
            exemplar_attn_64_3 = torch.mean(attns, dim=0, keepdim=True)


            cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
            cross_self_task_attn_maps.append(cross_self_task_attn)

            # if self.args.merge_exemplar == "average":
            cross_self_exe_attn1 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_1)
            cross_self_exe_attn2 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_2)
            cross_self_exe_attn3 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_3)
            exemplar_attn_64 = (exemplar_attn_64_1 + exemplar_attn_64_2 + exemplar_attn_64_3) / 3
            cross_self_exe_attn = (cross_self_exe_attn1 + cross_self_exe_attn2 + cross_self_exe_attn3) / 3

        exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6)

        attn_stack = [exemplar_attn_64 / 2, cross_self_exe_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        attn_stack = torch.cat(attn_stack, dim=1)

        if not self.use_box:

            # cross_self_exe_attn_np = cross_self_exe_attn.detach().squeeze().cpu().numpy()
            # boxes = gen_dummy_boxes(cross_self_exe_attn_np, max_boxes=1)
            # boxes = boxes.to(self.device)

            loca_out = self.loca_model.forward_before_reg(input_image, boxes)
        # (for the box path, loca_out was already computed above)
        loca_feature_bf_regression = loca_out["feature_bf_regression"]
        attn_out = self.loca_model.forward_reg(loca_out, attn_stack, feature_list[-1])
        pred_density = attn_out["pred"].squeeze().cpu().numpy()
        pred_cnt = pred_density.sum().item()

        # resize pred_density to original image size
        pred_density_rsz = cv2.resize(pred_density, (width, height), interpolation=cv2.INTER_CUBIC)
        pred_density_rsz = pred_density_rsz / pred_density_rsz.sum() * pred_cnt

        return pred_density_rsz, pred_cnt


def inference(data_path, box=None, save_path="./example_imgs", visualize=False):
    if box is not None:
        use_box = True
    else:
        use_box = False
    model = CountingModule(use_box=use_box)
    load_msg = model.load_state_dict(torch.load("pretrained/microscopy_matching_cnt.pth"), strict=True)
    model.eval()
    with torch.no_grad():
        density_map, cnt = model(data_path, box)

    if visualize:
        img = io.imread(data_path)
        if len(img.shape) == 3 and img.shape[2] > 3:
            img = img[:, :, :3]
        if len(img.shape) == 2:
            img = np.stack([img] * 3, axis=-1)
        img_show = img.squeeze()
        density_map_show = density_map.squeeze()
        os.makedirs(save_path, exist_ok=True)
        filename = data_path.split("/")[-1]
        img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show))
        fig, ax = plt.subplots(1, 2, figsize=(12, 6))
        ax[0].imshow(img_show)
        ax[0].axis('off')
        ax[0].set_title("Input image")
        ax[1].imshow(img_show)
        ax[1].imshow(density_map_show, cmap='jet', alpha=0.5)  # Overlay density map with some transparency
        ax[1].axis('off')
        ax[1].set_title(f"Predicted density map, count: {cnt:.1f}")
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, filename.split(".")[0] + "_cnt.png"), dpi=300)
        plt.close()
    return density_map

def main():

    inference(
        data_path="example_imgs/1977_Well_F-5_Field_1.png",
        # box=[[150, 60, 183, 87]],
        save_path="./example_imgs",
        visualize=True
    )

if __name__ == "__main__":
    main()
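For reference, a standalone shape walkthrough of the cross/self-attention fusion used in forward above (synthetic tensors; this mirrors attn_utils.self_cross_attn, where one token's cross-attention map re-weights the aggregated self-attention and the result is averaged over the key dimension):

# illustrative sketch, not part of the diff
import torch

res = 64
self_attn = torch.rand(res, res, res * res)  # aggregated 64x64 self-attention
token_map = torch.rand(1, 1, res, res)       # one token's cross-attention map

weighted = token_map.squeeze().reshape(res * res) * self_attn  # broadcast over keys
fused = weighted.mean(-1).unsqueeze(0).unsqueeze(0)
print(fused.shape)                           # torch.Size([1, 1, 64, 64])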
example_imgs/1977_Well_F-5_Field_1.png
ADDED
Git LFS Details
example_imgs/1977_Well_F-5_Field_1_seg.png
ADDED
Git LFS Details
models/.DS_Store
ADDED
Binary file (6.15 kB)
models/enc_model/__init__.py
ADDED
File without changes
models/enc_model/backbone.py
ADDED
@@ -0,0 +1,64 @@
import torch
from torch import nn
from torch.nn import functional as F
from torchvision import models
from torchvision.ops.misc import FrozenBatchNorm2d


class Backbone(nn.Module):

    def __init__(
        self,
        name: str,
        pretrained: bool,
        dilation: bool,
        reduction: int,
        swav: bool,
        requires_grad: bool
    ):

        super(Backbone, self).__init__()

        resnet = getattr(models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=pretrained, norm_layer=FrozenBatchNorm2d
        )

        self.backbone = resnet
        self.reduction = reduction

        if name == 'resnet50' and swav:
            checkpoint = torch.hub.load_state_dict_from_url(
                'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar',
                map_location="cpu"
            )
            state_dict = {k.replace("module.", ""): v for k, v in checkpoint.items()}
            self.backbone.load_state_dict(state_dict, strict=False)

        # concatenation of layers 2, 3 and 4
        self.num_channels = 896 if name in ['resnet18', 'resnet34'] else 3584

        for n, param in self.backbone.named_parameters():
            if 'layer2' not in n and 'layer3' not in n and 'layer4' not in n:
                param.requires_grad_(False)
            else:
                param.requires_grad_(requires_grad)

    def forward(self, x):
        size = x.size(-2) // self.reduction, x.size(-1) // self.reduction
        x = self.backbone.conv1(x)
        x = self.backbone.bn1(x)
        x = self.backbone.relu(x)
        x = self.backbone.maxpool(x)

        x = self.backbone.layer1(x)
        x = layer2 = self.backbone.layer2(x)
        x = layer3 = self.backbone.layer3(x)
        x = layer4 = self.backbone.layer4(x)

        x = torch.cat([
            F.interpolate(f, size=size, mode='bilinear', align_corners=True)
            for f in [layer2, layer3, layer4]
        ], dim=1)

        return x
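Shape sketch for the backbone (pretrained=False here only to avoid a weight download, and assuming a torchvision version that still accepts the pretrained keyword, as this repo's code does): with reduction=8, a 512x512 input yields 64x64 features, and resnet50's layers 2-4 contribute 512 + 1024 + 2048 = 3584 channels.

# illustrative sketch, not part of the diff
import torch
from models.enc_model.backbone import Backbone

bb = Backbone("resnet50", pretrained=False, dilation=False, reduction=8,
              swav=False, requires_grad=False)
print(bb(torch.randn(1, 3, 512, 512)).shape)  # torch.Size([1, 3584, 64, 64])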
models/enc_model/loca.py
ADDED
@@ -0,0 +1,232 @@
from .backbone import Backbone
from .transformer import TransformerEncoder
from .ope import OPEModule
from .positional_encoding import PositionalEncodingsFixed
from .regression_head import DensityMapRegressor

import torch
from torch import nn
from torch.nn import functional as F


class LOCA(nn.Module):

    def __init__(
        self,
        image_size: int,
        num_encoder_layers: int,
        num_ope_iterative_steps: int,
        num_objects: int,
        emb_dim: int,
        num_heads: int,
        kernel_dim: int,
        backbone_name: str,
        swav_backbone: bool,
        train_backbone: bool,
        reduction: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool,
    ):

        super(LOCA, self).__init__()

        self.emb_dim = emb_dim
        self.num_objects = num_objects
        self.reduction = reduction
        self.kernel_dim = kernel_dim
        self.image_size = image_size
        self.zero_shot = zero_shot
        self.num_heads = num_heads
        self.num_encoder_layers = num_encoder_layers

        self.backbone = Backbone(
            backbone_name, pretrained=True, dilation=False, reduction=reduction,
            swav=swav_backbone, requires_grad=train_backbone
        )
        self.input_proj = nn.Conv2d(
            self.backbone.num_channels, emb_dim, kernel_size=1
        )

        if num_encoder_layers > 0:
            self.encoder = TransformerEncoder(
                num_encoder_layers, emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation, norm
            )

        self.ope = OPEModule(
            num_ope_iterative_steps, emb_dim, kernel_dim, num_objects, num_heads,
            reduction, layer_norm_eps, mlp_factor, norm_first, activation, norm, zero_shot
        )

        self.regression_head = DensityMapRegressor(emb_dim, reduction)
        self.aux_heads = nn.ModuleList([
            DensityMapRegressor(emb_dim, reduction)
            for _ in range(num_ope_iterative_steps - 1)
        ])

        self.pos_emb = PositionalEncodingsFixed(emb_dim)

        self.attn_norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.fuse = nn.Sequential(
            nn.Conv2d(324, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64))
        )

        # self.fuse1 = nn.Sequential(
        #     nn.Conv2d(322, 256, kernel_size=1, stride=1),
        #     nn.LeakyReLU(),
        #     nn.LayerNorm((64, 64))
        # )

    def forward_before_reg(self, x, bboxes):
        num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects
        # backbone
        backbone_features = self.backbone(x)
        # prepare the encoder input
        src = self.input_proj(backbone_features)
        bs, c, h, w = src.size()
        pos_emb = self.pos_emb(bs, h, w, src.device).flatten(2).permute(2, 0, 1)
        src = src.flatten(2).permute(2, 0, 1)

        # push through the encoder
        if self.num_encoder_layers > 0:
            image_features = self.encoder(src, pos_emb, src_key_padding_mask=None, src_mask=None)
        else:
            image_features = src

        # prepare OPE input
        f_e = image_features.permute(1, 2, 0).reshape(-1, self.emb_dim, h, w)

        all_prototypes = self.ope(f_e, pos_emb, bboxes)  # [3, 27, 1, 256]

        outputs = list()
        response_maps_list = []
        for i in range(all_prototypes.size(0)):
            prototypes = all_prototypes[i, ...].permute(1, 0, 2).reshape(
                bs, num_objects, self.kernel_dim, self.kernel_dim, -1
            ).permute(0, 1, 4, 2, 3).flatten(0, 2)[:, None, ...]  # [768, 1, 3, 3]

            response_maps = F.conv2d(
                torch.cat([f_e for _ in range(num_objects)], dim=1).flatten(0, 1).unsqueeze(0),
                prototypes,
                bias=None,
                padding=self.kernel_dim // 2,
                groups=prototypes.size(0)
            ).view(
                bs, num_objects, self.emb_dim, h, w
            ).max(dim=1)[0]

            # # send through regression heads
            # if i == all_prototypes.size(0) - 1:
            #     predicted_dmaps = self.regression_head(response_maps)
            # else:
            #     predicted_dmaps = self.aux_heads[i](response_maps)
            # outputs.append(predicted_dmaps)
            response_maps_list.append(response_maps)

        out = {
            # "pred": outputs[-1],
            "feature_bf_regression": response_maps_list[-1],
            # "aux_pred": outputs[:-1],
            "aux_feature_bf_regression": response_maps_list[:-1]
        }

        return out

    def forward_reg(self, response_maps, attn_stack, unet_feature):
        attn_stack = self.attn_norm(attn_stack)
        attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        unet_feature = torch.cat([unet_feature, attn_stack], dim=1)  # [1, 324, 64, 64]
        unet_feature = unet_feature * attn_stack_mean
        if unet_feature.shape[1] == 322:
            # NOTE: self.fuse1 is commented out in __init__ above, so this branch would fail if reached
            unet_feature = self.fuse1(unet_feature)
        else:
            unet_feature = self.fuse(unet_feature)

        response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]

        outputs = []
        for i in range(len(response_maps)):
            response_map = response_maps[i] + unet_feature
            if i == len(response_maps) - 1:
                predicted_dmaps = self.regression_head(response_map)
            else:
                predicted_dmaps = self.aux_heads[i](response_map)
            outputs.append(predicted_dmaps)

        return {"pred": outputs[-1], "aux_pred": outputs[:-1]}

    def forward_reg1(self, response_maps, self_attn):
        # attn_stack = self.attn_norm(attn_stack)
        # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        # unet_feature = torch.cat([unet_feature, attn_stack], dim=1)  # [1, 324, 64, 64]
        # unet_feature = unet_feature * attn_stack_mean
        # if unet_feature.shape[1] == 322:
        #     unet_feature = self.fuse1(unet_feature)
|
| 172 |
+
# else:
|
| 173 |
+
# unet_feature = self.fuse(unet_feature)
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
|
| 178 |
+
|
| 179 |
+
outputs = []
|
| 180 |
+
for i in range(len(response_maps)):
|
| 181 |
+
response_map = response_maps[i] + self_attn
|
| 182 |
+
if i == len(response_maps) - 1:
|
| 183 |
+
predicted_dmaps = self.regression_head(response_map)
|
| 184 |
+
else:
|
| 185 |
+
predicted_dmaps = self.aux_heads[i](response_map)
|
| 186 |
+
outputs.append(predicted_dmaps)
|
| 187 |
+
|
| 188 |
+
return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
|
| 189 |
+
|
| 190 |
+
def forward_reg_without_unet(self, response_maps, attn_stack):
|
| 191 |
+
# attn_stack = self.attn_norm(attn_stack)
|
| 192 |
+
attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
|
| 193 |
+
|
| 194 |
+
response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]
|
| 195 |
+
|
| 196 |
+
outputs = []
|
| 197 |
+
for i in range(len(response_maps)):
|
| 198 |
+
response_map = response_maps[i] * attn_stack_mean * 0.5 + response_maps[i]
|
| 199 |
+
if i == len(response_maps) - 1:
|
| 200 |
+
predicted_dmaps = self.regression_head(response_map)
|
| 201 |
+
else:
|
| 202 |
+
predicted_dmaps = self.aux_heads[i](response_map)
|
| 203 |
+
outputs.append(predicted_dmaps)
|
| 204 |
+
|
| 205 |
+
return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def build_model(args):
|
| 209 |
+
|
| 210 |
+
assert args.backbone in ['resnet18', 'resnet50', 'resnet101']
|
| 211 |
+
assert args.reduction in [4, 8, 16]
|
| 212 |
+
|
| 213 |
+
return LOCA(
|
| 214 |
+
image_size=args.image_size,
|
| 215 |
+
num_encoder_layers=args.num_enc_layers,
|
| 216 |
+
num_ope_iterative_steps=args.num_ope_iterative_steps,
|
| 217 |
+
num_objects=args.num_objects,
|
| 218 |
+
zero_shot=args.zero_shot,
|
| 219 |
+
emb_dim=args.emb_dim,
|
| 220 |
+
num_heads=args.num_heads,
|
| 221 |
+
kernel_dim=args.kernel_dim,
|
| 222 |
+
backbone_name=args.backbone,
|
| 223 |
+
swav_backbone=args.swav_backbone,
|
| 224 |
+
train_backbone=args.backbone_lr > 0,
|
| 225 |
+
reduction=args.reduction,
|
| 226 |
+
dropout=args.dropout,
|
| 227 |
+
layer_norm_eps=1e-5,
|
| 228 |
+
mlp_factor=8,
|
| 229 |
+
norm_first=args.pre_norm,
|
| 230 |
+
activation=nn.GELU,
|
| 231 |
+
norm=True,
|
| 232 |
+
)
|
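The grouped convolution in forward_before_reg is the core correlation step: each exemplar prototype acts as a depthwise 3x3 filter over the encoded feature map, and the responses are max-reduced over exemplars. A minimal self-contained sketch with toy random tensors (shapes match the defaults above, not the repo's trained weights):

import torch
from torch.nn import functional as F

bs, num_objects, emb_dim, kernel_dim, h, w = 1, 3, 256, 3, 64, 64
f_e = torch.randn(bs, emb_dim, h, w)                        # encoder features
prototypes = torch.randn(bs * num_objects * emb_dim, 1, kernel_dim, kernel_dim)

response_maps = F.conv2d(
    torch.cat([f_e] * num_objects, dim=1).flatten(0, 1).unsqueeze(0),
    prototypes, bias=None, padding=kernel_dim // 2,
    groups=prototypes.size(0)                               # one filter per channel
).view(bs, num_objects, emb_dim, h, w).max(dim=1)[0]        # max over exemplars

print(response_maps.shape)  # torch.Size([1, 256, 64, 64])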
models/enc_model/loca_args.py
ADDED
@@ -0,0 +1,44 @@
import argparse


def get_argparser():

    parser = argparse.ArgumentParser("LOCA parser", add_help=False)

    parser.add_argument('--model_name', default='loca_few_shot', type=str)
    parser.add_argument(
        '--data_path',
        default='./data/FSC147_384_V2',
        type=str
    )
    parser.add_argument(
        '--model_path',
        default='ckpt',
        type=str
    )
    parser.add_argument('--backbone', default='resnet50', type=str)
    parser.add_argument('--swav_backbone', action='store_true', default=True)
    parser.add_argument('--reduction', default=8, type=int)
    parser.add_argument('--image_size', default=512, type=int)
    parser.add_argument('--num_enc_layers', default=3, type=int)
    parser.add_argument('--num_ope_iterative_steps', default=3, type=int)
    parser.add_argument('--emb_dim', default=256, type=int)
    parser.add_argument('--num_heads', default=8, type=int)
    parser.add_argument('--kernel_dim', default=3, type=int)
    parser.add_argument('--num_objects', default=3, type=int)
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--resume_training', action='store_true')
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--backbone_lr', default=0, type=float)
    parser.add_argument('--lr_drop', default=200, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--dropout', default=0.1, type=float)
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--max_grad_norm', default=0.1, type=float)
    parser.add_argument('--aux_weight', default=0.3, type=float)
    parser.add_argument('--tiling_p', default=0.5, type=float)
    parser.add_argument('--zero_shot', action='store_true')
    parser.add_argument('--pre_norm', action='store_true', default=True)

    return parser
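A hedged construction sketch, assuming the repo root is on PYTHONPATH (building the model triggers the pretrained/SwAV ResNet-50 download inside Backbone):

from models.enc_model.loca_args import get_argparser
from models.enc_model.loca import build_model

args = get_argparser().parse_args([])   # all defaults: resnet50, 512px, 3 exemplars
model = build_model(args)
print(sum(p.numel() for p in model.parameters() if p.requires_grad))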
models/enc_model/mlp.py
ADDED
@@ -0,0 +1,23 @@
from torch import nn


class MLP(nn.Module):

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        dropout: float,
        activation: nn.Module
    ):
        super(MLP, self).__init__()

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = activation()

    def forward(self, x):
        return (
            self.linear2(self.dropout(self.activation(self.linear1(x))))
        )
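The MLP maps back to input_dim, so it drops into the residual feed-forward slot of the transformer and OPE layers below. A quick shape check (assuming the repo root is importable):

import torch
from torch import nn
from models.enc_model.mlp import MLP

mlp = MLP(input_dim=256, hidden_dim=8 * 256, dropout=0.1, activation=nn.GELU)
x = torch.randn(10, 1, 256)             # (sequence, batch, emb_dim) tokens
assert mlp(x).shape == x.shape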
models/enc_model/ope.py
ADDED
@@ -0,0 +1,245 @@
from .mlp import MLP
from .positional_encoding import PositionalEncodingsFixed

import torch
from torch import nn

from torchvision.ops import roi_align


class OPEModule(nn.Module):

    def __init__(
        self,
        num_iterative_steps: int,
        emb_dim: int,
        kernel_dim: int,
        num_objects: int,
        num_heads: int,
        reduction: int,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool,
    ):

        super(OPEModule, self).__init__()

        self.num_iterative_steps = num_iterative_steps
        self.zero_shot = zero_shot
        self.kernel_dim = kernel_dim
        self.num_objects = num_objects
        self.emb_dim = emb_dim
        self.reduction = reduction

        if num_iterative_steps > 0:
            self.iterative_adaptation = IterativeAdaptationModule(
                num_layers=num_iterative_steps, emb_dim=emb_dim, num_heads=num_heads,
                dropout=0, layer_norm_eps=layer_norm_eps,
                mlp_factor=mlp_factor, norm_first=norm_first,
                activation=activation, norm=norm,
                zero_shot=zero_shot
            )

        if not self.zero_shot:
            self.shape_or_objectness = nn.Sequential(
                nn.Linear(2, 64),
                nn.ReLU(),
                nn.Linear(64, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, self.kernel_dim**2 * emb_dim)
            )
        else:
            self.shape_or_objectness = nn.Parameter(
                torch.empty((self.num_objects, self.kernel_dim**2, emb_dim))
            )
            nn.init.normal_(self.shape_or_objectness)

        self.pos_emb = PositionalEncodingsFixed(emb_dim)

    def forward(self, f_e, pos_emb, bboxes):
        bs, _, h, w = f_e.size()
        # extract the shape features or objectness
        if not self.zero_shot:
            box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
            box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
            box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]
            shape_or_objectness = self.shape_or_objectness(box_hw).reshape(
                bs, -1, self.kernel_dim ** 2, self.emb_dim
            ).flatten(1, 2).transpose(0, 1)
        else:
            shape_or_objectness = self.shape_or_objectness.expand(
                bs, -1, -1, -1
            ).flatten(1, 2).transpose(0, 1)

        # if not zero shot add appearance
        if not self.zero_shot:
            # reshape bboxes into the format suitable for roi_align
            num_of_boxes = bboxes.size(1)
            bboxes = torch.cat([
                torch.arange(
                    bs, requires_grad=False
                ).to(bboxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
                bboxes.flatten(0, 1),
            ], dim=1)
            appearance = roi_align(
                f_e,
                boxes=bboxes, output_size=self.kernel_dim,
                spatial_scale=1.0 / self.reduction, aligned=True
            ).permute(0, 2, 3, 1).reshape(
                bs, num_of_boxes * self.kernel_dim ** 2, -1
            ).transpose(0, 1)
        else:
            num_of_boxes = self.num_objects
            appearance = None

        query_pos_emb = self.pos_emb(
            bs, self.kernel_dim, self.kernel_dim, f_e.device
        ).flatten(2).permute(2, 0, 1).repeat(num_of_boxes, 1, 1)

        if self.num_iterative_steps > 0:
            memory = f_e.flatten(2).permute(2, 0, 1)
            all_prototypes = self.iterative_adaptation(
                shape_or_objectness, appearance, memory, pos_emb, query_pos_emb
            )
        else:
            if shape_or_objectness is not None and appearance is not None:
                all_prototypes = (shape_or_objectness + appearance).unsqueeze(0)
            else:
                all_prototypes = (
                    shape_or_objectness if shape_or_objectness is not None else appearance
                ).unsqueeze(0)

        return all_prototypes


class IterativeAdaptationModule(nn.Module):

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool
    ):

        super(IterativeAdaptationModule, self).__init__()

        self.layers = nn.ModuleList([
            IterativeAdaptationLayer(
                emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation, zero_shot
            ) for i in range(num_layers)
        ])

        self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()

    def forward(
        self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask=None, memory_mask=None,
        tgt_key_padding_mask=None, memory_key_padding_mask=None
    ):

        output = tgt
        outputs = list()
        for i, layer in enumerate(self.layers):
            output = layer(
                output, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
                tgt_key_padding_mask, memory_key_padding_mask
            )
            outputs.append(self.norm(output))

        return torch.stack(outputs)


class IterativeAdaptationLayer(nn.Module):

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        zero_shot: bool
    ):
        super(IterativeAdaptationLayer, self).__init__()

        self.norm_first = norm_first
        self.zero_shot = zero_shot

        if not self.zero_shot:
            self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm3 = nn.LayerNorm(emb_dim, layer_norm_eps)
        if not self.zero_shot:
            self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        if not self.zero_shot:
            self.self_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
        self.enc_dec_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)

        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        return x if emb is None else x + emb

    def forward(
        self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
        tgt_key_padding_mask, memory_key_padding_mask
    ):
        if self.norm_first:
            if not self.zero_shot:
                tgt_norm = self.norm1(tgt)
                tgt = tgt + self.dropout1(self.self_attn(
                    query=self.with_emb(tgt_norm, query_pos_emb),
                    key=self.with_emb(appearance, query_pos_emb),
                    value=appearance,
                    attn_mask=tgt_mask,
                    key_padding_mask=tgt_key_padding_mask
                )[0])

            tgt_norm = self.norm2(tgt)
            tgt = tgt + self.dropout2(self.enc_dec_attn(
                query=self.with_emb(tgt_norm, query_pos_emb),
                key=memory + pos_emb,
                value=memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask
            )[0])
            tgt_norm = self.norm3(tgt)
            tgt = tgt + self.dropout3(self.mlp(tgt_norm))

        else:
            if not self.zero_shot:
                tgt = self.norm1(tgt + self.dropout1(self.self_attn(
                    query=self.with_emb(tgt, query_pos_emb),
                    # fixed: with_emb takes (x, emb); the original call omitted
                    # the embedding argument and would raise a TypeError here
                    key=self.with_emb(appearance, query_pos_emb),
                    value=appearance,
                    attn_mask=tgt_mask,
                    key_padding_mask=tgt_key_padding_mask
                )[0]))

            tgt = self.norm2(tgt + self.dropout2(self.enc_dec_attn(
                query=self.with_emb(tgt, query_pos_emb),
                key=memory + pos_emb,
                value=memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask
            )[0]))

            tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))

        return tgt
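A hedged sketch of what OPEModule returns: one prototype set per iterative step, shaped [steps, num_objects * kernel_dim**2, bs, emb_dim]; for the defaults (3 steps, 3 exemplar boxes, 3x3 kernels, emb 256) that is [3, 27, 1, 256], matching the comment in loca.py. Toy random tensors only; box coordinates are in input pixels:

import torch
from torch import nn
from models.enc_model.ope import OPEModule

ope = OPEModule(
    num_iterative_steps=3, emb_dim=256, kernel_dim=3, num_objects=3,
    num_heads=8, reduction=8, layer_norm_eps=1e-5, mlp_factor=8,
    norm_first=True, activation=nn.GELU, norm=True, zero_shot=False
)
f_e = torch.randn(1, 256, 64, 64)                  # encoder feature map
pos = torch.randn(64 * 64, 1, 256)                 # flattened positional embeddings
boxes = torch.tensor([[[10., 10., 40., 40.]] * 3])  # [bs, 3, 4] exemplar boxes
print(ope(f_e, pos, boxes).shape)                  # torch.Size([3, 27, 1, 256])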
models/enc_model/positional_encoding.py
ADDED
@@ -0,0 +1,30 @@
import torch
from torch import nn


class PositionalEncodingsFixed(nn.Module):

    def __init__(self, emb_dim, temperature=10000):

        super(PositionalEncodingsFixed, self).__init__()

        self.emb_dim = emb_dim
        self.temperature = temperature

    def _1d_pos_enc(self, mask, dim):
        temp = torch.arange(self.emb_dim // 2).float().to(mask.device)
        temp = self.temperature ** (2 * (temp.div(2, rounding_mode='floor')) / self.emb_dim)

        enc = (~mask).cumsum(dim).float().unsqueeze(-1) / temp
        enc = torch.stack([
            enc[..., 0::2].sin(), enc[..., 1::2].cos()
        ], dim=-1).flatten(-2)

        return enc

    def forward(self, bs, h, w, device):
        mask = torch.zeros(bs, h, w, dtype=torch.bool, requires_grad=False, device=device)
        x = self._1d_pos_enc(mask, dim=2)
        y = self._1d_pos_enc(mask, dim=1)

        return torch.cat([y, x], dim=3).permute(0, 3, 1, 2)
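The fixed sine/cosine encodings produce a channel-first map that flattens into the (HW, bs, emb) sequences the transformer expects. A minimal check:

import torch
from models.enc_model.positional_encoding import PositionalEncodingsFixed

pe = PositionalEncodingsFixed(emb_dim=256)
enc = pe(bs=1, h=64, w=64, device=torch.device('cpu'))
print(enc.shape)                                  # torch.Size([1, 256, 64, 64])
print(enc.flatten(2).permute(2, 0, 1).shape)      # torch.Size([4096, 1, 256])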
models/enc_model/regression_head.py
ADDED
@@ -0,0 +1,92 @@
from torch import nn
import torch


class UpsamplingLayer(nn.Module):

    def __init__(self, in_channels, out_channels, leaky=True):

        super(UpsamplingLayer, self).__init__()

        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU() if leaky else nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2)
        )

    def forward(self, x):
        return self.layer(x)


class DensityMapRegressor(nn.Module):

    def __init__(self, in_channels, reduction):

        super(DensityMapRegressor, self).__init__()

        # only reductions 8 and 16 are handled; any other value leaves
        # self.regressor undefined and forward() would fail
        if reduction == 8:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                nn.Conv2d(32, 1, kernel_size=1),
                nn.LeakyReLU()
            )
        elif reduction == 16:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                UpsamplingLayer(32, 16),
                nn.Conv2d(16, 1, kernel_size=1),
                nn.LeakyReLU()
            )

        self.reset_parameters()

    def forward(self, x):
        return self.regressor(x)

    def reset_parameters(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)


class DensityMapRegressor_(nn.Module):

    def __init__(self, in_channels, reduction):

        # fixed: the original called super(DensityMapRegressor, self) here,
        # which raises a TypeError because this duplicate class does not
        # inherit from DensityMapRegressor
        super(DensityMapRegressor_, self).__init__()

        if reduction == 8:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                nn.Conv2d(32, 1, kernel_size=1),
                nn.LeakyReLU()
            )
        elif reduction == 16:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                UpsamplingLayer(32, 16),
                nn.Conv2d(16, 1, kernel_size=1),
                nn.LeakyReLU()
            )

        self.reset_parameters()

    def forward(self, x):
        return self.regressor(x)

    def reset_parameters(self):
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
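The reduction=8 head upsamples three times (2**3 = 8), so a 64x64 response map comes back at the 512x512 input resolution as a 1-channel density map. A toy shape check:

import torch
from models.enc_model.regression_head import DensityMapRegressor

head = DensityMapRegressor(in_channels=256, reduction=8)
dmap = head(torch.randn(1, 256, 64, 64))
print(dmap.shape)          # torch.Size([1, 1, 512, 512])
print(float(dmap.sum()))   # in density-map counting, the spatial sum is the count (once trained)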
models/enc_model/transformer.py
ADDED
@@ -0,0 +1,94 @@
from .mlp import MLP

from torch import nn


class TransformerEncoder(nn.Module):

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
    ):

        super(TransformerEncoder, self).__init__()

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation
            ) for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()

    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        output = src
        for layer in self.layers:
            output = layer(output, pos_emb, src_mask, src_key_padding_mask)
        return self.norm(output)


class TransformerEncoderLayer(nn.Module):

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
    ):
        super(TransformerEncoderLayer, self).__init__()

        self.norm_first = norm_first

        self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.self_attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout
        )
        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        return x if emb is None else x + emb

    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        if self.norm_first:
            src_norm = self.norm1(src)
            q = k = src_norm + pos_emb
            src = src + self.dropout1(self.self_attn(
                query=q,
                key=k,
                value=src_norm,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask
            )[0])

            src_norm = self.norm2(src)
            src = src + self.dropout2(self.mlp(src_norm))
        else:
            q = k = src + pos_emb
            src = self.norm1(src + self.dropout1(self.self_attn(
                query=q,
                key=k,
                value=src,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask
            )[0]))

            src = self.norm2(src + self.dropout2(self.mlp(src)))

        return src
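A hedged usage sketch of the pre-norm encoder over a flattened feature map; positional embeddings are added to queries and keys but not values:

import torch
from torch import nn
from models.enc_model.transformer import TransformerEncoder

enc = TransformerEncoder(
    num_layers=3, emb_dim=256, num_heads=8, dropout=0.1, layer_norm_eps=1e-5,
    mlp_factor=8, norm_first=True, activation=nn.GELU, norm=True
)
src = torch.randn(4096, 1, 256)     # (HW, bs, emb) tokens from a 64x64 map
pos = torch.randn(4096, 1, 256)
out = enc(src, pos, src_mask=None, src_key_padding_mask=None)
print(out.shape)                    # torch.Size([4096, 1, 256])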
models/enc_model/unet_parts.py
ADDED
@@ -0,0 +1,77 @@
""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F


class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        if not mid_channels:
            mid_channels = out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True)
        )

    def forward(self, x):
        return self.double_conv(x)


class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels)
        )

    def forward(self, x):
        return self.maxpool_conv(x)


class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        # if bilinear, use the normal convolutions to reduce the number of channels
        if bilinear:
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # input is CHW
        diffY = x2.size()[2] - x1.size()[2]
        diffX = x2.size()[3] - x1.size()[3]

        x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2,
                        diffY // 2, diffY - diffY // 2])
        # if you have padding issues, see
        # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a
        # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd
        x = torch.cat([x2, x1], dim=1)
        return self.conv(x)


class OutConv(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(OutConv, self).__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
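These are the standard U-Net building blocks. A minimal two-level U-Net assembled from them; TinyUNet and its channel widths are illustrative, not the repo's configuration:

import torch
from torch import nn
from models.enc_model.unet_parts import DoubleConv, Down, Up, OutConv

class TinyUNet(nn.Module):
    def __init__(self, in_ch=3, out_ch=1):
        super().__init__()
        self.inc = DoubleConv(in_ch, 32)
        self.down1 = Down(32, 64)
        self.up1 = Up(96, 32)        # cat(64 upsampled, 32 skip) = 96 channels in
        self.outc = OutConv(32, out_ch)

    def forward(self, x):
        x1 = self.inc(x)             # skip connection at full resolution
        x2 = self.down1(x1)          # half resolution
        return self.outc(self.up1(x2, x1))

print(TinyUNet()(torch.randn(1, 3, 64, 64)).shape)   # torch.Size([1, 1, 64, 64])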
models/model.py
ADDED
@@ -0,0 +1,991 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import os
import clip
import sys
from models.seg_post_model.cellpose.models import CellposeModel

from torchvision.ops import roi_align


def crop_roi_feat(feat, boxes):
    """
    feat: 1 x c x h x w
    boxes: m x 4, 4: [y_tl, x_tl, y_br, x_br]
    """
    _, _, h, w = feat.shape
    out_stride = 512 / h
    boxes_scaled = boxes / out_stride
    boxes_scaled[:, :2] = torch.floor(boxes_scaled[:, :2])  # y_tl, x_tl: floor
    boxes_scaled[:, 2:] = torch.ceil(boxes_scaled[:, 2:])   # y_br, x_br: ceil
    boxes_scaled[:, :2] = torch.clamp_min(boxes_scaled[:, :2], 0)
    boxes_scaled[:, 2] = torch.clamp_max(boxes_scaled[:, 2], h)
    boxes_scaled[:, 3] = torch.clamp_max(boxes_scaled[:, 3], w)
    feat_boxes = []
    for idx_box in range(0, boxes.shape[0]):
        y_tl, x_tl, y_br, x_br = boxes_scaled[idx_box]
        y_tl, x_tl, y_br, x_br = int(y_tl), int(x_tl), int(y_br), int(x_br)
        feat_box = feat[:, :, y_tl : (y_br + 1), x_tl : (x_br + 1)]
        feat_boxes.append(feat_box)
    return feat_boxes
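# Hedged usage sketch (illustration only, not part of the original file):
# crop_roi_feat assumes a 512px input and derives the stride from the feature
# height; boxes are [y_tl, x_tl, y_br, x_br] in input pixels. With stride
# 512 / 64 = 8, the box below maps to an 11x14 cell crop.
_demo_feat = torch.randn(1, 256, 64, 64)
_demo_boxes = torch.tensor([[16., 24., 95., 127.]])  # one box in pixel coords
assert crop_roi_feat(_demo_feat, _demo_boxes)[0].shape == (1, 256, 11, 14)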
| 30 |
+
|
| 31 |
+
class Counting_with_SD_features(nn.Module):
|
| 32 |
+
def __init__(self, scale_factor):
|
| 33 |
+
super(Counting_with_SD_features, self).__init__()
|
| 34 |
+
self.adapter = adapter_roi()
|
| 35 |
+
# self.regressor = regressor_with_SD_features()
|
| 36 |
+
|
| 37 |
+
class Counting_with_SD_features_loca(nn.Module):
|
| 38 |
+
def __init__(self, scale_factor):
|
| 39 |
+
super(Counting_with_SD_features_loca, self).__init__()
|
| 40 |
+
self.adapter = adapter_roi_loca()
|
| 41 |
+
self.regressor = regressor_with_SD_features()
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class Counting_with_SD_features_dino_vit_c3(nn.Module):
|
| 45 |
+
def __init__(self, scale_factor, vit=None):
|
| 46 |
+
super(Counting_with_SD_features_dino_vit_c3, self).__init__()
|
| 47 |
+
self.adapter = adapter_roi_loca()
|
| 48 |
+
self.regressor = regressor_with_SD_features_seg_vit_c3()
|
| 49 |
+
|
| 50 |
+
class Counting_with_SD_features_track(nn.Module):
|
| 51 |
+
def __init__(self, scale_factor, vit=None):
|
| 52 |
+
super(Counting_with_SD_features_track, self).__init__()
|
| 53 |
+
self.adapter = adapter_roi_loca()
|
| 54 |
+
self.regressor = regressor_with_SD_features_tra()
|
| 55 |
+
|
| 56 |
+
class Counting_with_SD_features_loca_rand(nn.Module):
|
| 57 |
+
def __init__(self, scale_factor, num_of_roi = 3):
|
| 58 |
+
super(Counting_with_SD_features_loca_rand, self).__init__()
|
| 59 |
+
self.adapter = adapter_roi_loca_rand(num_of_roi=num_of_roi)
|
| 60 |
+
self.regressor = regressor_with_SD_features()
|
| 61 |
+
|
| 62 |
+
class Counting_with_SD_features_loca_carpk(nn.Module):
|
| 63 |
+
def __init__(self, scale_factor, num_of_roi = 3):
|
| 64 |
+
super(Counting_with_SD_features_loca_carpk, self).__init__()
|
| 65 |
+
self.adapter = adapter_roi_loca_carpk(num_of_roi=num_of_roi)
|
| 66 |
+
self.regressor = regressor_with_SD_features()
|
| 67 |
+
|
| 68 |
+
class Counting_with_SD_features_clip_carpk(nn.Module):
|
| 69 |
+
def __init__(self, scale_factor, num_of_roi = 3):
|
| 70 |
+
super(Counting_with_SD_features_clip_carpk, self).__init__()
|
| 71 |
+
self.adapter = adapter_roi_clip_carpk(num_of_roi=num_of_roi)
|
| 72 |
+
# self.regressor = regressor_with_SD_features()
|
| 73 |
+
|
| 74 |
+
class Counting_with_SD_features_zero(nn.Module):
|
| 75 |
+
def __init__(self, scale_factor):
|
| 76 |
+
super(Counting_with_SD_features_zero, self).__init__()
|
| 77 |
+
self.adapter = adapter_roi_zero()
|
| 78 |
+
self.regressor = regressor_with_SD_features()
|
| 79 |
+
|
| 80 |
+
class Counting_with_SD_features_zero_loca(nn.Module):
|
| 81 |
+
def __init__(self, scale_factor):
|
| 82 |
+
super(Counting_with_SD_features_zero_loca, self).__init__()
|
| 83 |
+
self.adapter = adapter_roi_zero_loca()
|
| 84 |
+
self.regressor = regressor_with_SD_features()
|
| 85 |
+
|
| 86 |
+
class Counting_with_SD_features_zero_loca_self(nn.Module):
|
| 87 |
+
def __init__(self, scale_factor):
|
| 88 |
+
super(Counting_with_SD_features_zero_loca_self, self).__init__()
|
| 89 |
+
self.adapter = adapter_roi_zero_loca()
|
| 90 |
+
# self.regressor = regressor_with_SD_features_self()
|
| 91 |
+
self.regressor = regressor_with_SD_features_latent()
|
| 92 |
+
|
| 93 |
+
class Counting_with_SD_features_loca_v2(nn.Module):
|
| 94 |
+
def __init__(self, scale_factor):
|
| 95 |
+
super(Counting_with_SD_features_loca_v2, self).__init__()
|
| 96 |
+
self.adapter = adapter_roi_loca_v2()
|
| 97 |
+
# self.regressor = regressor_with_SD_features()
|
| 98 |
+
|
| 99 |
+
class adapter1(nn.Module):
|
| 100 |
+
def __init__(self):
|
| 101 |
+
super(adapter1, self).__init__()
|
| 102 |
+
self.conv1 = nn.Conv2d(256, 128, kernel_size=3, padding=1)
|
| 103 |
+
self.pool = nn.MaxPool2d(2)
|
| 104 |
+
self.fc = nn.Linear(128 * 64 * 64, 768)
|
| 105 |
+
self.initialize_weights()
|
| 106 |
+
|
| 107 |
+
def forward(self, x):
|
| 108 |
+
x = self.conv1(x)
|
| 109 |
+
x = self.pool(x)
|
| 110 |
+
x = x.view(x.size(0), -1)
|
| 111 |
+
x = self.fc(x)
|
| 112 |
+
return x
|
| 113 |
+
|
| 114 |
+
def initialize_weights(self):
|
| 115 |
+
for m in self.modules():
|
| 116 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 117 |
+
nn.init.xavier_normal_(m.weight)
|
| 118 |
+
if m.bias is not None:
|
| 119 |
+
nn.init.constant_(m.bias, 0)
|
| 120 |
+
|
| 121 |
+
class adapter(nn.Module):
|
| 122 |
+
def __init__(self, pool_size=[3, 3]):
|
| 123 |
+
super(adapter, self).__init__()
|
| 124 |
+
self.pool_size = pool_size
|
| 125 |
+
self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
| 126 |
+
self.pool = nn.MaxPool2d(2)
|
| 127 |
+
self.fc = nn.Linear(256 * 3 * 3, 768)
|
| 128 |
+
self.initialize_weights()
|
| 129 |
+
|
| 130 |
+
def forward(self, xs):
|
| 131 |
+
x_list = []
|
| 132 |
+
for x in xs:
|
| 133 |
+
x = F.adaptive_max_pool2d(x, self.pool_size, return_indices=False) # [1, 256, 3, 3]
|
| 134 |
+
x_list.append(x)
|
| 135 |
+
x_list = torch.cat(x_list, dim=0)
|
| 136 |
+
x_list = torch.mean(x_list, dim=0, keepdim=True) # [1, 256, 3, 3]
|
| 137 |
+
x = self.conv1(x_list)
|
| 138 |
+
# x = self.pool(x)
|
| 139 |
+
x = x.view(x.size(0), -1)
|
| 140 |
+
x = self.fc(x)
|
| 141 |
+
return x
|
| 142 |
+
|
| 143 |
+
def initialize_weights(self):
|
| 144 |
+
for m in self.modules():
|
| 145 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 146 |
+
nn.init.xavier_normal_(m.weight)
|
| 147 |
+
if m.bias is not None:
|
| 148 |
+
nn.init.constant_(m.bias, 0)
|
| 149 |
+
|
| 150 |
+
class adapter_roi(nn.Module):
|
| 151 |
+
def __init__(self, pool_size=[3, 3]):
|
| 152 |
+
super(adapter_roi, self).__init__()
|
| 153 |
+
self.pool_size = pool_size
|
| 154 |
+
self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
| 155 |
+
# self.relu = nn.ReLU()
|
| 156 |
+
# self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
| 157 |
+
self.pool = nn.MaxPool2d(2)
|
| 158 |
+
self.fc = nn.Linear(256 * 3 * 3, 768)
|
| 159 |
+
# **new
|
| 160 |
+
self.fc1 = nn.Sequential(
|
| 161 |
+
nn.ReLU(),
|
| 162 |
+
nn.Linear(768, 768 // 4, bias=False),
|
| 163 |
+
nn.ReLU()
|
| 164 |
+
)
|
| 165 |
+
self.fc2 = nn.Sequential(
|
| 166 |
+
nn.Linear(768 // 4, 768, bias=False),
|
| 167 |
+
# nn.ReLU()
|
| 168 |
+
)
|
| 169 |
+
self.initialize_weights()
|
| 170 |
+
|
| 171 |
+
def forward(self, x, boxes):
|
| 172 |
+
num_of_boxes = boxes.shape[1]
|
| 173 |
+
rois = []
|
| 174 |
+
bs, _, h, w = x.shape
|
| 175 |
+
boxes = torch.cat([
|
| 176 |
+
torch.arange(
|
| 177 |
+
bs, requires_grad=False
|
| 178 |
+
).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
|
| 179 |
+
boxes.flatten(0, 1),
|
| 180 |
+
], dim=1)
|
| 181 |
+
rois = roi_align(
|
| 182 |
+
x,
|
| 183 |
+
boxes=boxes, output_size=3,
|
| 184 |
+
spatial_scale=1.0 / 8, aligned=True
|
| 185 |
+
)
|
| 186 |
+
rois = torch.mean(rois, dim=0, keepdim=True)
|
| 187 |
+
x = self.conv1(rois)
|
| 188 |
+
x = x.view(x.size(0), -1)
|
| 189 |
+
x = self.fc(x)
|
| 190 |
+
|
| 191 |
+
x = self.fc1(x)
|
| 192 |
+
x = self.fc2(x)
|
| 193 |
+
return x
|
| 194 |
+
|
| 195 |
+
|
| 196 |
+
def initialize_weights(self):
|
| 197 |
+
for m in self.modules():
|
| 198 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 199 |
+
nn.init.xavier_normal_(m.weight)
|
| 200 |
+
if m.bias is not None:
|
| 201 |
+
nn.init.constant_(m.bias, 0)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
class adapter_roi_loca(nn.Module):
|
| 205 |
+
def __init__(self, pool_size=[3, 3]):
|
| 206 |
+
super(adapter_roi_loca, self).__init__()
|
| 207 |
+
self.pool_size = pool_size
|
| 208 |
+
self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
| 209 |
+
self.pool = nn.MaxPool2d(2)
|
| 210 |
+
self.fc = nn.Linear(256 * 3 * 3, 768)
|
| 211 |
+
self.initialize_weights()
|
| 212 |
+
def forward(self, x, boxes):
|
| 213 |
+
num_of_boxes = boxes.shape[1]
|
| 214 |
+
rois = []
|
| 215 |
+
bs, _, h, w = x.shape
|
| 216 |
+
if h != 512 or w != 512:
|
| 217 |
+
x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)
|
| 218 |
+
if bs == 1:
|
| 219 |
+
boxes = torch.cat([
|
| 220 |
+
torch.arange(
|
| 221 |
+
bs, requires_grad=False
|
| 222 |
+
).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
|
| 223 |
+
boxes.flatten(0, 1),
|
| 224 |
+
], dim=1)
|
| 225 |
+
rois = roi_align(
|
| 226 |
+
x,
|
| 227 |
+
boxes=boxes, output_size=3,
|
| 228 |
+
spatial_scale=1.0 / 8, aligned=True
|
| 229 |
+
)
|
| 230 |
+
rois = torch.mean(rois, dim=0, keepdim=True)
|
| 231 |
+
else:
|
| 232 |
+
boxes = torch.cat([
|
| 233 |
+
boxes.flatten(0, 1),
|
| 234 |
+
], dim=1).split(num_of_boxes, dim=0)
|
| 235 |
+
rois = roi_align(
|
| 236 |
+
x,
|
| 237 |
+
boxes=boxes, output_size=3,
|
| 238 |
+
spatial_scale=1.0 / 8, aligned=True
|
| 239 |
+
)
|
| 240 |
+
rois = rois.split(num_of_boxes, dim=0)
|
| 241 |
+
rois = torch.stack(rois, dim=0)
|
| 242 |
+
rois = torch.mean(rois, dim=1, keepdim=False)
|
| 243 |
+
x = self.conv1(rois)
|
| 244 |
+
x = x.view(x.size(0), -1)
|
| 245 |
+
x = self.fc(x)
|
| 246 |
+
return x
|
| 247 |
+
|
| 248 |
+
def forward_boxes(self, x, boxes):
|
| 249 |
+
num_of_boxes = boxes.shape[1]
|
| 250 |
+
rois = []
|
| 251 |
+
bs, _, h, w = x.shape
|
| 252 |
+
if h != 512 or w != 512:
|
| 253 |
+
x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)
|
| 254 |
+
if bs == 1:
|
| 255 |
+
boxes = torch.cat([
|
| 256 |
+
torch.arange(
|
| 257 |
+
bs, requires_grad=False
|
| 258 |
+
).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
|
| 259 |
+
boxes.flatten(0, 1),
|
| 260 |
+
], dim=1)
|
| 261 |
+
rois = roi_align(
|
| 262 |
+
x,
|
| 263 |
+
boxes=boxes, output_size=3,
|
| 264 |
+
spatial_scale=1.0 / 8, aligned=True
|
| 265 |
+
)
|
| 266 |
+
# rois = torch.mean(rois, dim=0, keepdim=True)
|
| 267 |
+
else:
|
| 268 |
+
raise NotImplementedError
|
| 269 |
+
x = self.conv1(rois)
|
| 270 |
+
x = x.view(x.size(0), -1)
|
| 271 |
+
x = self.fc(x)
|
| 272 |
+
return x
|
| 273 |
+
|
| 274 |
+
def initialize_weights(self):
|
| 275 |
+
for m in self.modules():
|
| 276 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 277 |
+
nn.init.xavier_normal_(m.weight)
|
| 278 |
+
if m.bias is not None:
|
| 279 |
+
nn.init.constant_(m.bias, 0)
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
class adapter_roi_dino(nn.Module):
|
| 283 |
+
def __init__(self, pool_size=[3, 3]):
|
| 284 |
+
super(adapter_roi_dino, self).__init__()
|
| 285 |
+
self.pool_size = pool_size
|
| 286 |
+
# self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
| 287 |
+
# self.pool = nn.MaxPool2d(2)
|
| 288 |
+
self.fc = nn.Linear(1024, 768)
|
| 289 |
+
self.initialize_weights()
|
| 290 |
+
def forward(self, crops, dino_model):
|
| 291 |
+
num_of_boxes = len(crops)
|
| 292 |
+
feats = []
|
| 293 |
+
for i in range(num_of_boxes):
|
| 294 |
+
with torch.no_grad():
|
| 295 |
+
feat = dino_model(crops[i])
|
| 296 |
+
|
| 297 |
+
feats.append(feat)
|
| 298 |
+
feats = torch.cat(feats, dim=0)
|
| 299 |
+
feats = torch.mean(feats, dim=0)
|
| 300 |
+
x = self.fc(feats)
|
| 301 |
+
return x
|
| 302 |
+
def initialize_weights(self):
|
| 303 |
+
for m in self.modules():
|
| 304 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 305 |
+
nn.init.xavier_normal_(m.weight)
|
| 306 |
+
if m.bias is not None:
|
| 307 |
+
nn.init.constant_(m.bias, 0)
|
| 308 |
+
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
class adapter_roi_loca_v2(nn.Module):
|
| 312 |
+
def __init__(self, pool_size=[3, 3]):
|
| 313 |
+
super(adapter_roi_loca_v2, self).__init__()
|
| 314 |
+
self.pool_size = pool_size
|
| 315 |
+
self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
| 316 |
+
self.pool = nn.MaxPool2d(2)
|
| 317 |
+
self.fc = nn.Linear(256 * 3 * 3, 1024)
|
| 318 |
+
self.initialize_weights()
|
| 319 |
+
def forward(self, x, boxes):
|
| 320 |
+
rois = []
|
| 321 |
+
bs, _, h, w = x.shape
|
| 322 |
+
boxes = torch.cat([
|
| 323 |
+
torch.arange(
|
| 324 |
+
bs, requires_grad=False
|
| 325 |
+
).to(boxes.device).repeat_interleave(3).reshape(-1, 1),
|
| 326 |
+
boxes.flatten(0, 1),
|
| 327 |
+
], dim=1)
|
| 328 |
+
rois = roi_align(
|
| 329 |
+
x,
|
| 330 |
+
boxes=boxes, output_size=3,
|
| 331 |
+
spatial_scale=1.0 / 8, aligned=True
|
| 332 |
+
)
|
| 333 |
+
rois = torch.mean(rois, dim=0, keepdim=True)
|
| 334 |
+
x = self.conv1(rois)
|
| 335 |
+
x = x.view(x.size(0), -1)
|
| 336 |
+
x = self.fc(x)
|
| 337 |
+
return x
|
| 338 |
+
def initialize_weights(self):
|
| 339 |
+
for m in self.modules():
|
| 340 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 341 |
+
nn.init.xavier_normal_(m.weight)
|
| 342 |
+
if m.bias is not None:
|
| 343 |
+
nn.init.constant_(m.bias, 0)
|
| 344 |
+
|
| 345 |
+
class adapter_roi_zero(nn.Module):
|
| 346 |
+
def __init__(self, reduction=4):
|
| 347 |
+
super(adapter_roi_zero, self).__init__()
|
| 348 |
+
self.fc1 = nn.Sequential(
|
| 349 |
+
nn.Linear(768, 768 // reduction, bias=False),
|
| 350 |
+
nn.ReLU()
|
| 351 |
+
)
|
| 352 |
+
self.fc2 = nn.Sequential(
|
| 353 |
+
nn.Linear(768 // reduction, 768, bias=False),
|
| 354 |
+
nn.ReLU()
|
| 355 |
+
)
|
| 356 |
+
self.initialize_weights()
|
| 357 |
+
def forward(self, x):
|
| 358 |
+
x1 = self.fc1(x)
|
| 359 |
+
x1 = self.fc2(x1)
|
| 360 |
+
return x + x1
|
| 361 |
+
def initialize_weights(self):
|
| 362 |
+
for m in self.modules():
|
| 363 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 364 |
+
nn.init.xavier_normal_(m.weight)
|
| 365 |
+
if m.bias is not None:
|
| 366 |
+
nn.init.constant_(m.bias, 0)
|
| 367 |
+
|
| 368 |
+
class adapter_roi_zero_loca(nn.Module):
|
| 369 |
+
def __init__(self, reduction=4):
|
| 370 |
+
super(adapter_roi_zero_loca, self).__init__()
|
| 371 |
+
self.fc1 = nn.Sequential(
|
| 372 |
+
nn.Linear(768, 768 // reduction, bias=False),
|
| 373 |
+
nn.ReLU()
|
| 374 |
+
)
|
| 375 |
+
self.fc2 = nn.Sequential(
|
| 376 |
+
nn.Linear(768 // reduction, 768, bias=False),
|
| 377 |
+
nn.ReLU()
|
| 378 |
+
)
|
| 379 |
+
|
| 380 |
+
self.pool_size = (3, 3)
|
| 381 |
+
self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
| 382 |
+
self.pool = nn.MaxPool2d(2)
|
| 383 |
+
self.fc = nn.Linear(256 * 3 * 3, 768)
|
| 384 |
+
|
| 385 |
+
self.initialize_weights()
|
| 386 |
+
def forward(self, feature, boxes, class_emb):
|
| 387 |
+
x1 = self.fc1(class_emb)
|
| 388 |
+
x1 = self.fc2(x1)
|
| 389 |
+
class_emb = class_emb + x1
|
| 390 |
+
|
| 391 |
+
rois = []
|
| 392 |
+
bs, _, h, w = feature.shape
|
| 393 |
+
n_box = boxes.shape[1]
|
| 394 |
+
boxes = torch.cat([
|
| 395 |
+
torch.arange(
|
| 396 |
+
bs, requires_grad=False
|
| 397 |
+
).to(boxes.device).repeat_interleave(n_box).reshape(-1, 1),
|
| 398 |
+
boxes.flatten(0, 1),
|
| 399 |
+
], dim=1)
|
| 400 |
+
rois = roi_align(
|
| 401 |
+
feature,
|
| 402 |
+
boxes=boxes, output_size=3,
|
| 403 |
+
spatial_scale=1.0 / 8, aligned=True
|
| 404 |
+
)
|
| 405 |
+
# rois = torch.mean(rois, dim=0, keepdim=True)
|
| 406 |
+
x = self.conv1(rois)
|
| 407 |
+
x = x.view(x.size(0), -1)
|
| 408 |
+
x = self.fc(x)
|
| 409 |
+
|
| 410 |
+
if len(class_emb.shape) == 3:
|
| 411 |
+
class_emb = class_emb.squeeze(1)
|
| 412 |
+
dist = torch.cosine_similarity(class_emb, x) # [n_box]
|
| 413 |
+
_, topk = torch.sort(dist[:10])
|
| 414 |
+
x_topk = x[topk[:3], :]
|
| 415 |
+
x_topk = torch.mean(x_topk, dim=0, keepdim=True)
|
| 416 |
+
return x_topk + class_emb
|
| 417 |
+
|
| 418 |
+
def vis(self, feature, boxes, class_emb):
|
| 419 |
+
x1 = self.fc1(class_emb)
|
| 420 |
+
x1 = self.fc2(x1)
|
| 421 |
+
class_emb = class_emb + x1
|
| 422 |
+
|
| 423 |
+
rois = []
|
| 424 |
+
bs, _, h, w = feature.shape
|
| 425 |
+
n_box = boxes.shape[1]
|
| 426 |
+
boxes = torch.cat([
|
| 427 |
+
torch.arange(
|
| 428 |
+
bs, requires_grad=False
|
| 429 |
+
).to(boxes.device).repeat_interleave(n_box).reshape(-1, 1),
|
| 430 |
+
boxes.flatten(0, 1),
|
| 431 |
+
], dim=1)
|
| 432 |
+
rois = roi_align(
|
| 433 |
+
feature,
|
| 434 |
+
boxes=boxes, output_size=3,
|
| 435 |
+
spatial_scale=1.0 / 8, aligned=True
|
| 436 |
+
)
|
| 437 |
+
# rois = torch.mean(rois, dim=0, keepdim=True)
|
| 438 |
+
x = self.conv1(rois)
|
| 439 |
+
x = x.view(x.size(0), -1)
|
| 440 |
+
x = self.fc(x)
|
| 441 |
+
|
| 442 |
+
if len(class_emb.shape) == 3:
|
| 443 |
+
class_emb = class_emb.squeeze(1)
|
| 444 |
+
dist = torch.cosine_similarity(class_emb, x) # [n_box]
|
| 445 |
+
_, topk = torch.sort(dist[:10])
|
| 446 |
+
x_topk = x[topk[:3], :]
|
| 447 |
+
x_topk = torch.mean(x_topk, dim=0, keepdim=True)
|
| 448 |
+
return x_topk
|
| 449 |
+
|
| 450 |
+
def initialize_weights(self):
|
| 451 |
+
for m in self.modules():
|
| 452 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 453 |
+
nn.init.xavier_normal_(m.weight)
|
| 454 |
+
if m.bias is not None:
|
| 455 |
+
nn.init.constant_(m.bias, 0)
|
| 456 |
+
|
| 457 |
+
class adapter_roi_loca_rand(nn.Module):
|
| 458 |
+
def __init__(self, pool_size=[3, 3],num_of_roi = 3):
|
| 459 |
+
super(adapter_roi_loca_rand, self).__init__()
|
| 460 |
+
self.pool_size = pool_size
|
| 461 |
+
self.num_of_roi = num_of_roi
|
| 462 |
+
self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
|
| 463 |
+
self.pool = nn.MaxPool2d(2)
|
| 464 |
+
self.fc = nn.Linear(256 * 3 * 3, 768)
|
| 465 |
+
|
| 466 |
+
# # **new
|
| 467 |
+
# self.fc1 = nn.Sequential(
|
| 468 |
+
# nn.Linear(768, 768 // 4, bias=False),
|
| 469 |
+
# nn.ReLU()
|
| 470 |
+
# )
|
| 471 |
+
# self.fc2 = nn.Sequential(
|
| 472 |
+
# nn.Linear(768 // 4, 768, bias=False),
|
| 473 |
+
# nn.ReLU()
|
| 474 |
+
# )
|
| 475 |
+
# #
|
| 476 |
+
self.initialize_weights()
|
| 477 |
+
def forward(self, x, boxes, rand_boxes):
|
| 478 |
+
num_of_boxes = boxes.shape[1]
|
| 479 |
+
bs, _, h, w = x.shape
|
| 480 |
+
boxes = torch.cat([
|
| 481 |
+
torch.arange(
|
| 482 |
+
bs, requires_grad=False
|
| 483 |
+
).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
|
| 484 |
+
boxes.flatten(0, 1),
|
| 485 |
+
], dim=1)
|
| 486 |
+
rois = roi_align(
|
| 487 |
+
x,
|
| 488 |
+
boxes=boxes, output_size=3,
|
| 489 |
+
spatial_scale=1.0 / 8, aligned=True
|
| 490 |
+
)
|
| 491 |
+
|
| 492 |
+
# new
|
| 493 |
+
num_of_boxes = rand_boxes.shape[1]
|
| 494 |
+
bs, _, h, w = x.shape
|
| 495 |
+
rand_boxes = torch.cat([
|
| 496 |
+
torch.arange(
|
| 497 |
+
bs, requires_grad=False
|
| 498 |
+
).to(rand_boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
|
| 499 |
+
rand_boxes.flatten(0, 1),
|
| 500 |
+
], dim=1)
|
| 501 |
+
rand_rois = roi_align(
|
| 502 |
+
x,
|
| 503 |
+
boxes=rand_boxes, output_size=3,
|
| 504 |
+
spatial_scale=1.0 / 8, aligned=True
|
| 505 |
+
)
|
| 506 |
+
|
| 507 |
+
rois = torch.mean(rois, dim=0, keepdim=True)
|
| 508 |
+
|
| 509 |
+
# new
|
| 510 |
+
cos = torch.nn.CosineSimilarity(dim=1)
|
| 511 |
+
dist = cos(rois.view(1, -1), rand_rois.view(num_of_boxes, -1)) # [n_box]
|
| 512 |
+
_, topk = torch.sort(-dist)  # sort negated similarities: most-similar candidates first
|
| 513 |
+
x_topk = rand_rois[topk[:self.num_of_roi], ...]  # take the num_of_roi most similar random ROIs
|
| 514 |
+
x_topk = torch.mean(x_topk, dim=0, keepdim=True)
|
| 515 |
+
|
| 516 |
+
rois += x_topk
|
| 517 |
+
|
| 518 |
+
x = self.conv1(rois)
|
| 519 |
+
x = x.view(x.size(0), -1)
|
| 520 |
+
x = self.fc(x)
|
| 521 |
+
# new
|
| 522 |
+
# x = self.fc1(x)
|
| 523 |
+
# x = self.fc2(x)
|
| 524 |
+
return x
|
| 525 |
+
|
| 526 |
+
def initialize_weights(self):
|
| 527 |
+
for m in self.modules():
|
| 528 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 529 |
+
nn.init.xavier_normal_(m.weight)
|
| 530 |
+
if m.bias is not None:
|
| 531 |
+
nn.init.constant_(m.bias, 0)
|
| 532 |
+
|
| 533 |
+
|
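The two adapters rank candidate ROI features by cosine similarity in opposite directions: `torch.sort(dist[:10])` in `forward`/`vis` is ascending (lowest similarity first), while `torch.sort(-dist)` above puts the most similar candidates first, equivalent to `torch.topk`. A small self-contained check with made-up vectors:

```python
import torch

q = torch.randn(768)                     # query embedding (e.g. pooled class embedding)
cands = torch.randn(10, 768)             # candidate ROI features
dist = torch.cosine_similarity(q.unsqueeze(0), cands)   # [10]

_, asc = torch.sort(dist)                # ascending: asc[:3] are the 3 LEAST similar
_, desc = torch.sort(-dist)              # negated:   desc[:3] are the 3 MOST similar
top3 = torch.topk(dist, k=3).indices     # same indices as desc[:3]
print(torch.equal(desc[:3], top3))       # True (up to ties)
```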
| 534 |
+
class regressor1(nn.Module):
|
| 535 |
+
def __init__(self):
|
| 536 |
+
super(regressor1, self).__init__()
|
| 537 |
+
self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
|
| 538 |
+
self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
|
| 539 |
+
self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
|
| 540 |
+
self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
|
| 541 |
+
self.leaky_relu = nn.LeakyReLU()
|
| 542 |
+
self.relu = nn.ReLU()
|
| 543 |
+
self.initialize_weights()
|
| 544 |
+
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def forward(self, x):
|
| 548 |
+
x_ = self.conv1(x)
|
| 549 |
+
x_ = self.leaky_relu(x_)
|
| 550 |
+
x_ = self.upsampler(x_)
|
| 551 |
+
x_ = self.conv2(x_)
|
| 552 |
+
x_ = self.leaky_relu(x_)
|
| 553 |
+
x_ = self.upsampler(x_)
|
| 554 |
+
x_ = self.conv3(x_)
|
| 555 |
+
x_ = self.relu(x_)
|
| 556 |
+
out = x_
|
| 557 |
+
return out
|
| 558 |
+
|
| 559 |
+
def initialize_weights(self):
|
| 560 |
+
for m in self.modules():
|
| 561 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 562 |
+
nn.init.xavier_normal_(m.weight)
|
| 563 |
+
if m.bias is not None:
|
| 564 |
+
nn.init.constant_(m.bias, 0)
|
| 565 |
+
|
| 566 |
+
|
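`regressor1` doubles the spatial resolution twice, so a 4-channel 64×64 attention stack comes out as a single-channel 256×256 map. A quick shape check (the input size is an assumption based on the attention maps used elsewhere in this file):

```python
import torch

model = regressor1()                  # the class defined above
x = torch.randn(1, 4, 64, 64)         # assumed: a 4-channel cross-attention stack
out = model(x)
print(out.shape)                      # torch.Size([1, 1, 256, 256]) after two 2x upsamplings
```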
| 594 |
+
|
| 595 |
+
|
| 596 |
+
class regressor_with_SD_features(nn.Module):
|
| 597 |
+
def __init__(self):
|
| 598 |
+
super(regressor_with_SD_features, self).__init__()
|
| 599 |
+
self.layer1 = nn.Sequential(
|
| 600 |
+
nn.Conv2d(324, 256, kernel_size=1, stride=1),
|
| 601 |
+
nn.LeakyReLU(),
|
| 602 |
+
nn.LayerNorm((64, 64))
|
| 603 |
+
)
|
| 604 |
+
self.layer2 = nn.Sequential(
|
| 605 |
+
nn.Conv2d(256, 128, kernel_size=3, padding=1),
|
| 606 |
+
nn.LeakyReLU(),
|
| 607 |
+
nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
|
| 608 |
+
)
|
| 609 |
+
self.layer3 = nn.Sequential(
|
| 610 |
+
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
| 611 |
+
nn.ReLU(),
|
| 612 |
+
nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
|
| 613 |
+
)
|
| 614 |
+
self.layer4 = nn.Sequential(
|
| 615 |
+
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
| 616 |
+
nn.LeakyReLU(),
|
| 617 |
+
nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
|
| 618 |
+
)
|
| 619 |
+
self.conv = nn.Sequential(
|
| 620 |
+
nn.Conv2d(32, 1, kernel_size=1),
|
| 621 |
+
nn.ReLU()
|
| 622 |
+
)
|
| 623 |
+
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 624 |
+
self.initialize_weights()
|
| 625 |
+
|
| 626 |
+
def forward(self, attn_stack, feature_list):
|
| 627 |
+
attn_stack = self.norm(attn_stack)
|
| 628 |
+
unet_feature = feature_list[-1]
|
| 629 |
+
attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
|
| 630 |
+
unet_feature = unet_feature * attn_stack_mean
|
| 631 |
+
unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
|
| 632 |
+
x = self.layer1(unet_feature)
|
| 633 |
+
x = self.layer2(x)
|
| 634 |
+
x = self.layer3(x)
|
| 635 |
+
x = self.layer4(x)
|
| 636 |
+
out = self.conv(x)
|
| 637 |
+
return out / 100
|
| 638 |
+
|
| 639 |
+
def initialize_weights(self):
|
| 640 |
+
for m in self.modules():
|
| 641 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 642 |
+
nn.init.xavier_normal_(m.weight)
|
| 643 |
+
if m.bias is not None:
|
| 644 |
+
nn.init.constant_(m.bias, 0)
|
| 645 |
+
|
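The decoder above widens 64×64 to 512×512 through three stride-2 transposed convolutions (64→128→256→512), and the final `out / 100` keeps predicted density values small; the count is recovered by summing the map. A shape sketch, assuming 68 attention channels so that the concatenation matches `layer1`'s 324 input channels:

```python
import torch

head = regressor_with_SD_features()
attn_stack = torch.randn(1, 68, 64, 64)        # assumed 68 attention channels (256 + 68 = 324)
feature_list = [torch.randn(1, 256, 64, 64)]   # last entry is the U-Net feature used above
density = head(attn_stack, feature_list)
print(density.shape, density.sum())            # [1, 1, 512, 512]; the sum estimates the count
```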
| 646 |
+
class regressor_with_SD_features_seg(nn.Module):
|
| 647 |
+
def __init__(self):
|
| 648 |
+
super(regressor_with_SD_features_seg, self).__init__()
|
| 649 |
+
self.layer1 = nn.Sequential(
|
| 650 |
+
nn.Conv2d(324, 256, kernel_size=1, stride=1),
|
| 651 |
+
nn.LeakyReLU(),
|
| 652 |
+
nn.LayerNorm((64, 64))
|
| 653 |
+
)
|
| 654 |
+
self.layer2 = nn.Sequential(
|
| 655 |
+
nn.Conv2d(256, 128, kernel_size=3, padding=1),
|
| 656 |
+
nn.LeakyReLU(),
|
| 657 |
+
nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
|
| 658 |
+
)
|
| 659 |
+
self.layer3 = nn.Sequential(
|
| 660 |
+
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
| 661 |
+
nn.ReLU(),
|
| 662 |
+
nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
|
| 663 |
+
)
|
| 664 |
+
self.layer4 = nn.Sequential(
|
| 665 |
+
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
| 666 |
+
nn.LeakyReLU(),
|
| 667 |
+
nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
|
| 668 |
+
)
|
| 669 |
+
self.conv = nn.Sequential(
|
| 670 |
+
nn.Conv2d(32, 2, kernel_size=1),
|
| 671 |
+
# nn.ReLU()
|
| 672 |
+
)
|
| 673 |
+
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 674 |
+
self.initialize_weights()
|
| 675 |
+
|
| 676 |
+
def forward(self, attn_stack, feature_list):
|
| 677 |
+
attn_stack = self.norm(attn_stack)
|
| 678 |
+
unet_feature = feature_list[-1]
|
| 679 |
+
attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
|
| 680 |
+
unet_feature = unet_feature * attn_stack_mean
|
| 681 |
+
unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
|
| 682 |
+
x = self.layer1(unet_feature)
|
| 683 |
+
x = self.layer2(x)
|
| 684 |
+
x = self.layer3(x)
|
| 685 |
+
x = self.layer4(x)
|
| 686 |
+
out = self.conv(x)
|
| 687 |
+
return out
|
| 688 |
+
|
| 689 |
+
def initialize_weights(self):
|
| 690 |
+
for m in self.modules():
|
| 691 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 692 |
+
nn.init.xavier_normal_(m.weight)
|
| 693 |
+
if m.bias is not None:
|
| 694 |
+
nn.init.constant_(m.bias, 0)
|
| 695 |
+
|
| 696 |
+
|
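Unlike the density head, this `_seg` variant ends in a 2-channel conv with no ReLU, i.e. per-pixel class logits. A sketch (not code from this repo) of how such logits are typically turned into a mask and trained:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(1, 2, 512, 512)       # output of regressor_with_SD_features_seg
probs = F.softmax(logits, dim=1)           # per-pixel class probabilities
mask = logits.argmax(dim=1)                # [1, 512, 512] hard labels in {0, 1}

# During training, a long-dtype target mask of shape [1, 512, 512] fits cross_entropy:
target = torch.randint(0, 2, (1, 512, 512))
loss = F.cross_entropy(logits, target)
print(mask.shape, float(loss))
```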
| 697 |
+
from models.enc_model.unet_parts import *
|
| 698 |
+
|
| 699 |
+
|
| 700 |
+
class regressor_with_SD_features_seg_vit_c3(nn.Module):
|
| 701 |
+
def __init__(self, n_channels=3, n_classes=2, bilinear=False):
|
| 702 |
+
super(regressor_with_SD_features_seg_vit_c3, self).__init__()
|
| 703 |
+
self.n_channels = n_channels
|
| 704 |
+
self.n_classes = n_classes
|
| 705 |
+
self.bilinear = bilinear
|
| 706 |
+
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 707 |
+
self.inc_0 = nn.Conv2d(n_channels, 3, kernel_size=3, padding=1)
|
| 708 |
+
self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
|
| 709 |
+
self.vit = self.vit_model.net
|
| 710 |
+
|
| 711 |
+
def forward(self, img, attn_stack, feature_list):
|
| 712 |
+
attn_stack = attn_stack[:, [1,3], ...]
|
| 713 |
+
attn_stack = self.norm(attn_stack)
|
| 714 |
+
unet_feature = feature_list[-1]
|
| 715 |
+
unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
|
| 716 |
+
|
| 717 |
+
x = torch.cat([unet_feature_mean, attn_stack], dim=1)  # [1, 3, 64, 64]: mean U-Net feature + 2 attention channels
|
| 718 |
+
|
| 719 |
+
if x.shape[-1] != 512:
|
| 720 |
+
x = F.interpolate(x, size=(512, 512), mode="bilinear")
|
| 721 |
+
x = self.inc_0(x)
|
| 722 |
+
|
| 723 |
+
|
| 724 |
+
|
| 725 |
+
out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
|
| 726 |
+
out = torch.from_numpy(out).unsqueeze(0).to(x.device)
|
| 727 |
+
return out
|
| 728 |
+
|
| 729 |
+
def initialize_weights(self):
|
| 730 |
+
for m in self.modules():
|
| 731 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 732 |
+
nn.init.xavier_normal_(m.weight)
|
| 733 |
+
if m.bias is not None:
|
| 734 |
+
nn.init.constant_(m.bias, 0)
|
| 735 |
+
|
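`attn_stack[:, [1, 3], ...]` keeps only attention channels 1 and 3; together with the mean U-Net feature this gives the 3 channels `inc_0` expects, and `F.interpolate` matches the 512×512 input size of the Cellpose backbone. Both steps in isolation:

```python
import torch
import torch.nn.functional as F

attn_stack = torch.randn(1, 4, 64, 64)
selected = attn_stack[:, [1, 3], ...]                    # keep channels 1 and 3 -> [1, 2, 64, 64]
mean_feat = torch.randn(1, 1, 64, 64)                    # stand-in for the mean U-Net feature
x = torch.cat([mean_feat, selected], dim=1)              # [1, 3, 64, 64]
x = F.interpolate(x, size=(512, 512), mode="bilinear")   # [1, 3, 512, 512]
print(x.shape)
```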
| 736 |
+
class regressor_with_SD_features_tra(nn.Module):
|
| 737 |
+
def __init__(self, n_channels=2, n_classes=2, bilinear=False):
|
| 738 |
+
super(regressor_with_SD_features_tra, self).__init__()
|
| 739 |
+
self.n_channels = n_channels
|
| 740 |
+
self.n_classes = n_classes
|
| 741 |
+
self.bilinear = bilinear
|
| 742 |
+
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 743 |
+
|
| 744 |
+
# segmentation
|
| 745 |
+
self.inc_0 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
|
| 746 |
+
self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
|
| 747 |
+
self.vit = self.vit_model.net
|
| 748 |
+
|
| 749 |
+
self.inc_1 = nn.Conv2d(n_channels, 1, kernel_size=3, padding=1)
|
| 750 |
+
self.mlp = nn.Linear(64 * 64, 320)
|
| 751 |
+
# self.vit = self.vit_model.net.float()
|
| 752 |
+
|
| 753 |
+
def forward_seg(self, img, attn_stack, feature_list, mask, training=False):
|
| 754 |
+
attn_stack = attn_stack[:, [1,3], ...]
|
| 755 |
+
attn_stack = self.norm(attn_stack)
|
| 756 |
+
unet_feature = feature_list[-1]
|
| 757 |
+
unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
|
| 758 |
+
x = torch.cat([unet_feature_mean, attn_stack], dim=1)  # [1, 3, 64, 64]: mean U-Net feature + 2 attention channels
|
| 759 |
+
|
| 760 |
+
if x.shape[-1] != 512:
|
| 761 |
+
x = F.interpolate(x, size=(512, 512), mode="bilinear")
|
| 762 |
+
x = self.inc_0(x)
|
| 763 |
+
feat = x
|
| 764 |
+
|
| 765 |
+
out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0]
|
| 766 |
+
out = torch.from_numpy(out).unsqueeze(0).to(x.device)
|
| 767 |
+
return out, 0., feat
|
| 768 |
+
|
| 769 |
+
def forward(self, attn_prev, feature_list_prev, attn_after, feature_list_after):
|
| 770 |
+
assert attn_prev.shape == attn_after.shape, "attn_prev and attn_after must have the same shape"
|
| 771 |
+
n_instances = attn_prev.shape[0]
|
| 772 |
+
attn_prev = self.norm(attn_prev) # [n_instances, 1, 64, 64]
|
| 773 |
+
attn_after = self.norm(attn_after)
|
| 774 |
+
|
| 775 |
+
x = torch.cat([attn_prev, attn_after], dim=1) # n_instances, 2, 64, 64
|
| 776 |
+
|
| 777 |
+
x = self.inc_1(x)
|
| 778 |
+
x = x.view(1, n_instances, -1)  # flatten each instance map to [1, n_instances, 64*64]
|
| 779 |
+
x = self.mlp(x) # Apply the MLP to get the output
|
| 780 |
+
|
| 781 |
+
return x  # output shape is [1, n_instances, 320]
|
| 782 |
+
|
| 783 |
+
|
| 784 |
+
|
| 785 |
+
def initialize_weights(self):
|
| 786 |
+
for m in self.modules():
|
| 787 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 788 |
+
nn.init.xavier_normal_(m.weight)
|
| 789 |
+
if m.bias is not None:
|
| 790 |
+
nn.init.constant_(m.bias, 0)
|
| 791 |
+
|
| 792 |
+
|
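The tracking `forward` pairs per-instance attention maps from consecutive frames and maps each pair to a 320-d vector via `inc_1` and the linear layer. A shape trace of that path, with the layers re-declared standalone so the snippet runs on its own:

```python
import torch
import torch.nn as nn

n_instances = 5
inc_1 = nn.Conv2d(2, 1, kernel_size=3, padding=1)   # mirrors self.inc_1 with n_channels=2
mlp = nn.Linear(64 * 64, 320)                       # mirrors self.mlp

attn_prev = torch.randn(n_instances, 1, 64, 64)
attn_after = torch.randn(n_instances, 1, 64, 64)

x = torch.cat([attn_prev, attn_after], dim=1)       # [n_instances, 2, 64, 64]
x = inc_1(x)                                        # [n_instances, 1, 64, 64]
x = x.view(1, n_instances, -1)                      # [1, n_instances, 64*64]
x = mlp(x)                                          # [1, n_instances, 320]
print(x.shape)
```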
| 793 |
+
|
| 794 |
+
class regressor_with_SD_features_inst_seg_unet(nn.Module):
|
| 795 |
+
def __init__(self, n_channels=8, n_classes=3, bilinear=False):
|
| 796 |
+
super(regressor_with_SD_features_inst_seg_unet, self).__init__()
|
| 797 |
+
self.n_channels = n_channels
|
| 798 |
+
self.n_classes = n_classes
|
| 799 |
+
self.bilinear = bilinear
|
| 800 |
+
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 801 |
+
self.inc_0 = (DoubleConv(n_channels, 3))
|
| 802 |
+
self.inc = (DoubleConv(3, 64))
|
| 803 |
+
self.down1 = (Down(64, 128))
|
| 804 |
+
self.down2 = (Down(128, 256))
|
| 805 |
+
self.down3 = (Down(256, 512))
|
| 806 |
+
factor = 2 if bilinear else 1
|
| 807 |
+
self.down4 = (Down(512, 1024 // factor))
|
| 808 |
+
self.up1 = (Up(1024, 512 // factor, bilinear))
|
| 809 |
+
self.up2 = (Up(512, 256 // factor, bilinear))
|
| 810 |
+
self.up3 = (Up(256, 128 // factor, bilinear))
|
| 811 |
+
self.up4 = (Up(128, 64, bilinear))
|
| 812 |
+
self.outc = (OutConv(64, n_classes))
|
| 813 |
+
|
| 814 |
+
def forward(self, img, attn_stack, feature_list):
|
| 815 |
+
attn_stack = self.norm(attn_stack)
|
| 816 |
+
unet_feature = feature_list[-1]
|
| 817 |
+
unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True)
|
| 818 |
+
attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
|
| 819 |
+
unet_feature_mean = unet_feature_mean * attn_stack_mean
|
| 820 |
+
x = torch.cat([unet_feature_mean, attn_stack], dim=1)  # mean U-Net feature + attention maps along channels
|
| 821 |
+
if x.shape[-1] != 512:
|
| 822 |
+
x = F.interpolate(x, size=(512, 512), mode="bilinear")
|
| 823 |
+
x = torch.cat([img, x], dim=1) # [1, 8, 512, 512]
|
| 824 |
+
x = self.inc_0(x)
|
| 825 |
+
x1 = self.inc(x)
|
| 826 |
+
x2 = self.down1(x1)
|
| 827 |
+
x3 = self.down2(x2)
|
| 828 |
+
x4 = self.down3(x3)
|
| 829 |
+
x5 = self.down4(x4)
|
| 830 |
+
x = self.up1(x5, x4)
|
| 831 |
+
x = self.up2(x, x3)
|
| 832 |
+
x = self.up3(x, x2)
|
| 833 |
+
x = self.up4(x, x1)
|
| 834 |
+
out = self.outc(x)
|
| 835 |
+
return out
|
| 836 |
+
|
| 837 |
+
def initialize_weights(self):
|
| 838 |
+
for m in self.modules():
|
| 839 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 840 |
+
nn.init.xavier_normal_(m.weight)
|
| 841 |
+
if m.bias is not None:
|
| 842 |
+
nn.init.constant_(m.bias, 0)
|
| 843 |
+
|
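Channel accounting for the instance-segmentation U-Net above (an assumption, since the attention stack's width is not fixed here): a 3-channel image plus the 1-channel mean U-Net feature plus 4 attention channels gives the `n_channels=8` that `inc_0` consumes. A shape run under that assumption and a standard `unet_parts` implementation:

```python
import torch

net = regressor_with_SD_features_inst_seg_unet(n_channels=8, n_classes=3)
img = torch.randn(1, 3, 512, 512)
attn_stack = torch.randn(1, 4, 64, 64)         # assumed 4 attention channels
feature_list = [torch.randn(1, 256, 64, 64)]
out = net(img, attn_stack, feature_list)
print(out.shape)                               # expected torch.Size([1, 3, 512, 512])
```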
| 844 |
+
|
| 845 |
+
class regressor_with_SD_features_self(nn.Module):
|
| 846 |
+
def __init__(self):
|
| 847 |
+
super(regressor_with_SD_features_self, self).__init__()
|
| 848 |
+
self.layer = nn.Sequential(
|
| 849 |
+
nn.Conv2d(4096, 1024, kernel_size=1, stride=1),
|
| 850 |
+
nn.LeakyReLU(),
|
| 851 |
+
nn.LayerNorm((64, 64)),
|
| 852 |
+
nn.Conv2d(1024, 256, kernel_size=1, stride=1),
|
| 853 |
+
nn.LeakyReLU(),
|
| 854 |
+
nn.LayerNorm((64, 64)),
|
| 855 |
+
)
|
| 856 |
+
self.layer2 = nn.Sequential(
|
| 857 |
+
nn.Conv2d(256, 128, kernel_size=3, padding=1),
|
| 858 |
+
nn.LeakyReLU(),
|
| 859 |
+
nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
|
| 860 |
+
)
|
| 861 |
+
self.layer3 = nn.Sequential(
|
| 862 |
+
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
| 863 |
+
nn.ReLU(),
|
| 864 |
+
nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
|
| 865 |
+
)
|
| 866 |
+
self.layer4 = nn.Sequential(
|
| 867 |
+
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
| 868 |
+
nn.LeakyReLU(),
|
| 869 |
+
nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
|
| 870 |
+
)
|
| 871 |
+
self.conv = nn.Sequential(
|
| 872 |
+
nn.Conv2d(32, 1, kernel_size=1),
|
| 873 |
+
nn.ReLU()
|
| 874 |
+
)
|
| 875 |
+
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 876 |
+
self.initialize_weights()
|
| 877 |
+
|
| 878 |
+
def forward(self, self_attn):
|
| 879 |
+
self_attn = self_attn.permute(2, 0, 1)
|
| 880 |
+
self_attn = self.layer(self_attn)
|
| 881 |
+
return self_attn
|
| 882 |
+
# attn_stack = self.norm(attn_stack)
|
| 883 |
+
# unet_feature = feature_list[-1]
|
| 884 |
+
# attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
|
| 885 |
+
# unet_feature = unet_feature * attn_stack_mean
|
| 886 |
+
# unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
|
| 887 |
+
# x = self.layer(unet_feature)
|
| 888 |
+
# x = self.layer2(x)
|
| 889 |
+
# x = self.layer3(x)
|
| 890 |
+
# x = self.layer4(x)
|
| 891 |
+
# out = self.conv(x)
|
| 892 |
+
# return out / 100
|
| 893 |
+
|
| 894 |
+
def initialize_weights(self):
|
| 895 |
+
for m in self.modules():
|
| 896 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 897 |
+
nn.init.xavier_normal_(m.weight)
|
| 898 |
+
if m.bias is not None:
|
| 899 |
+
nn.init.constant_(m.bias, 0)
|
| 900 |
+
|
| 901 |
+
|
| 902 |
+
class regressor_with_SD_features_latent(nn.Module):
|
| 903 |
+
def __init__(self):
|
| 904 |
+
super(regressor_with_SD_features_latent, self).__init__()
|
| 905 |
+
self.layer = nn.Sequential(
|
| 906 |
+
nn.Conv2d(4, 256, kernel_size=1, stride=1),
|
| 907 |
+
nn.LeakyReLU(),
|
| 908 |
+
nn.LayerNorm((64, 64))
|
| 909 |
+
)
|
| 910 |
+
self.layer2 = nn.Sequential(
|
| 911 |
+
nn.Conv2d(256, 128, kernel_size=3, padding=1),
|
| 912 |
+
nn.LeakyReLU(),
|
| 913 |
+
nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
|
| 914 |
+
)
|
| 915 |
+
self.layer3 = nn.Sequential(
|
| 916 |
+
nn.Conv2d(128, 64, kernel_size=3, padding=1),
|
| 917 |
+
nn.ReLU(),
|
| 918 |
+
nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
|
| 919 |
+
)
|
| 920 |
+
self.layer4 = nn.Sequential(
|
| 921 |
+
nn.Conv2d(64, 32, kernel_size=3, padding=1),
|
| 922 |
+
nn.LeakyReLU(),
|
| 923 |
+
nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
|
| 924 |
+
)
|
| 925 |
+
self.conv = nn.Sequential(
|
| 926 |
+
nn.Conv2d(32, 1, kernel_size=1),
|
| 927 |
+
nn.ReLU()
|
| 928 |
+
)
|
| 929 |
+
self.norm = nn.LayerNorm(normalized_shape=(64, 64))
|
| 930 |
+
self.initialize_weights()
|
| 931 |
+
|
| 932 |
+
def forward(self, self_attn):
|
| 933 |
+
# self_attn = self_attn.permute(2, 0, 1)
|
| 934 |
+
self_attn = self.layer(self_attn)
|
| 935 |
+
return self_attn
|
| 936 |
+
# attn_stack = self.norm(attn_stack)
|
| 937 |
+
# unet_feature = feature_list[-1]
|
| 938 |
+
# attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
|
| 939 |
+
# unet_feature = unet_feature * attn_stack_mean
|
| 940 |
+
# unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64]
|
| 941 |
+
# x = self.layer(unet_feature)
|
| 942 |
+
# x = self.layer2(x)
|
| 943 |
+
# x = self.layer3(x)
|
| 944 |
+
# x = self.layer4(x)
|
| 945 |
+
# out = self.conv(x)
|
| 946 |
+
# return out / 100
|
| 947 |
+
|
| 948 |
+
def initialize_weights(self):
|
| 949 |
+
for m in self.modules():
|
| 950 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
|
| 951 |
+
nn.init.xavier_normal_(m.weight)
|
| 952 |
+
if m.bias is not None:
|
| 953 |
+
nn.init.constant_(m.bias, 0)
|
| 954 |
+
|
| 955 |
+
|
| 956 |
+
|
| 957 |
+
|
| 958 |
+
|
| 959 |
+
class regressor_with_deconv(nn.Module):
|
| 960 |
+
def __init__(self):
|
| 961 |
+
super(regressor_with_deconv, self).__init__()
|
| 962 |
+
self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
|
| 963 |
+
self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
|
| 964 |
+
self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
|
| 965 |
+
self.deconv1 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
|
| 966 |
+
self.deconv2 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
|
| 967 |
+
self.leaky_relu = nn.LeakyReLU()
|
| 968 |
+
self.relu = nn.ReLU()
|
| 969 |
+
self.initialize_weights()
|
| 970 |
+
|
| 971 |
+
def forward(self, x):
|
| 972 |
+
x_ = self.conv1(x)
|
| 973 |
+
x_ = self.leaky_relu(x_)
|
| 974 |
+
x_ = self.deconv1(x_)
|
| 975 |
+
x_ = self.conv2(x_)
|
| 976 |
+
x_ = self.leaky_relu(x_)
|
| 977 |
+
x_ = self.deconv2(x_)
|
| 978 |
+
x_ = self.conv3(x_)
|
| 979 |
+
x_ = self.relu(x_)
|
| 980 |
+
out = x_
|
| 981 |
+
return out
|
| 982 |
+
|
| 983 |
+
def initialize_weights(self):
|
| 984 |
+
for m in self.modules():
|
| 985 |
+
if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear):
|
| 986 |
+
nn.init.xavier_normal_(m.weight)
|
| 987 |
+
if m.bias is not None:
|
| 988 |
+
nn.init.constant_(m.bias, 0)
|
| 989 |
+
|
| 990 |
+
|
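`regressor_with_deconv` is the learned-upsampling counterpart of `regressor1`: `ConvTranspose2d(kernel_size=4, stride=2, padding=1)` doubles the spatial size exactly like `UpsamplingBilinear2d(scale_factor=2)`, but with trainable weights. A quick size check of the two ops in isolation:

```python
import torch
import torch.nn as nn

x = torch.randn(1, 4, 64, 64)
up = nn.UpsamplingBilinear2d(scale_factor=2)
deconv = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)

print(up(x).shape, deconv(x).shape)   # both torch.Size([1, 4, 128, 128])
```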
| 991 |
+
|
models/seg_post_model/cellpose/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
| 1 |
+
from .version import version, version_str
|
models/seg_post_model/cellpose/__main__.py
ADDED
|
@@ -0,0 +1,272 @@
|
| 1 |
+
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import os, time
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from cellpose import utils, models, io, train
|
| 8 |
+
from .version import version_str
|
| 9 |
+
from cellpose.cli import get_arg_parser
|
| 10 |
+
|
| 11 |
+
try:
|
| 12 |
+
from cellpose.gui import gui3d, gui
|
| 13 |
+
GUI_ENABLED = True
|
| 14 |
+
except ImportError as err:
|
| 15 |
+
GUI_ERROR = err
|
| 16 |
+
GUI_ENABLED = False
|
| 17 |
+
GUI_IMPORT = True
|
| 18 |
+
except Exception as err:
|
| 19 |
+
GUI_ENABLED = False
|
| 20 |
+
GUI_ERROR = err
|
| 21 |
+
GUI_IMPORT = False
|
| 22 |
+
raise
|
| 23 |
+
|
| 24 |
+
import logging
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
|
| 28 |
+
""" Run cellpose from command line
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
args = get_arg_parser().parse_args() # this has to be in a separate file for autodoc to work
|
| 32 |
+
|
| 33 |
+
if args.version:
|
| 34 |
+
print(version_str)
|
| 35 |
+
return
|
| 36 |
+
|
| 37 |
+
######## if no image arguments are provided, run GUI or add model and exit ########
|
| 38 |
+
if len(args.dir) == 0 and len(args.image_path) == 0:
|
| 39 |
+
if args.add_model:
|
| 40 |
+
io.add_model(args.add_model)
|
| 41 |
+
return
|
| 42 |
+
else:
|
| 43 |
+
if not GUI_ENABLED:
|
| 44 |
+
print("GUI ERROR: %s" % GUI_ERROR)
|
| 45 |
+
if GUI_IMPORT:
|
| 46 |
+
print(
|
| 47 |
+
"GUI FAILED: GUI dependencies may not be installed, to install, run"
|
| 48 |
+
)
|
| 49 |
+
print(" pip install 'cellpose[gui]'")
|
| 50 |
+
else:
|
| 51 |
+
if args.Zstack:
|
| 52 |
+
gui3d.run()
|
| 53 |
+
else:
|
| 54 |
+
gui.run()
|
| 55 |
+
return
|
| 56 |
+
|
| 57 |
+
############################## run cellpose on images ##############################
|
| 58 |
+
if args.verbose:
|
| 59 |
+
from .io import logger_setup
|
| 60 |
+
logger, log_file = logger_setup()
|
| 61 |
+
else:
|
| 62 |
+
print(
|
| 63 |
+
">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
|
| 64 |
+
print("No --verbose => no progress or info printed")
|
| 65 |
+
logger = logging.getLogger(__name__)
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
# find images
|
| 69 |
+
if len(args.img_filter) > 0:
|
| 70 |
+
image_filter = args.img_filter
|
| 71 |
+
else:
|
| 72 |
+
image_filter = None
|
| 73 |
+
|
| 74 |
+
device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu,
|
| 75 |
+
device=args.gpu_device)
|
| 76 |
+
|
| 77 |
+
if args.pretrained_model is None or args.pretrained_model == "None" or args.pretrained_model == "False" or args.pretrained_model == "0":
|
| 78 |
+
pretrained_model = "cpsam"
|
| 79 |
+
logger.warning("training from scratch is disabled, using 'cpsam' model")
|
| 80 |
+
else:
|
| 81 |
+
pretrained_model = args.pretrained_model
|
| 82 |
+
|
| 83 |
+
# Warn users about old arguments from CP3:
|
| 84 |
+
if args.pretrained_model_ortho:
|
| 85 |
+
logger.warning(
|
| 86 |
+
"the '--pretrained_model_ortho' flag is deprecated in v4.0.1+ and no longer used")
|
| 87 |
+
if args.train_size:
|
| 88 |
+
logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")
|
| 89 |
+
if args.chan or args.chan2:
|
| 90 |
+
logger.warning('--chan and --chan2 are deprecated, all channels are used by default')
|
| 91 |
+
if args.all_channels:
|
| 92 |
+
logger.warning("the '--all_channels' flag is deprecated in v4.0.1+ and no longer used")
|
| 93 |
+
if args.restore_type:
|
| 94 |
+
logger.warning("the '--restore_type' flag is deprecated in v4.0.1+ and no longer used")
|
| 95 |
+
if args.transformer:
|
| 96 |
+
logger.warning("the '--tranformer' flag is deprecated in v4.0.1+ and no longer used")
|
| 97 |
+
if args.invert:
|
| 98 |
+
logger.warning("the '--invert' flag is deprecated in v4.0.1+ and no longer used")
|
| 99 |
+
if args.chan2_restore:
|
| 100 |
+
logger.warning("the '--chan2_restore' flag is deprecated in v4.0.1+ and no longer used")
|
| 101 |
+
if args.diam_mean:
|
| 102 |
+
logger.warning("the '--diam_mean' flag is deprecated in v4.0.1+ and no longer used")
|
| 105 |
+
|
| 106 |
+
if args.norm_percentile is not None:
|
| 107 |
+
value1, value2 = args.norm_percentile
|
| 108 |
+
normalize = {'percentile': (float(value1), float(value2))}
|
| 109 |
+
else:
|
| 110 |
+
normalize = (not args.no_norm)
|
| 111 |
+
|
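`normalize` therefore ends up as either a bool or a dict with a `percentile` key; both forms are forwarded unchanged to `model.eval(..., normalize=...)`. For illustration:

```python
# The two shapes `normalize` can take after the block above (values illustrative):
normalize_from_flag = {"percentile": (1.0, 99.0)}   # from --norm_percentile 1 99
normalize_default = True                            # i.e. not args.no_norm
# either value is passed straight through to model.eval(..., normalize=...)
```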
| 112 |
+
if args.save_each:
|
| 113 |
+
if not args.save_every:
|
| 114 |
+
raise ValueError("ERROR: --save_each requires --save_every")
|
| 115 |
+
|
| 116 |
+
if len(args.image_path) > 0 and args.train:
|
| 117 |
+
raise ValueError("ERROR: cannot train model with single image input")
|
| 118 |
+
|
| 119 |
+
## Run evaluation on images
|
| 120 |
+
if not args.train:
|
| 121 |
+
_evaluate_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
|
| 122 |
+
|
| 123 |
+
## Train a model ##
|
| 124 |
+
else:
|
| 125 |
+
_train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize):
|
| 129 |
+
test_dir = None if len(args.test_dir) == 0 else args.test_dir
|
| 130 |
+
images, labels, image_names, train_probs = None, None, None, None
|
| 131 |
+
test_images, test_labels, image_names_test, test_probs = None, None, None, None
|
| 132 |
+
compute_flows = False
|
| 133 |
+
if len(args.file_list) > 0:
|
| 134 |
+
if os.path.exists(args.file_list):
|
| 135 |
+
dat = np.load(args.file_list, allow_pickle=True).item()
|
| 136 |
+
image_names = dat["train_files"]
|
| 137 |
+
image_names_test = dat.get("test_files", None)
|
| 138 |
+
train_probs = dat.get("train_probs", None)
|
| 139 |
+
test_probs = dat.get("test_probs", None)
|
| 140 |
+
compute_flows = dat.get("compute_flows", False)
|
| 141 |
+
load_files = False
|
| 142 |
+
else:
|
| 143 |
+
logger.critical(f"ERROR: {args.file_list} does not exist")
|
| 144 |
+
else:
|
| 145 |
+
output = io.load_train_test_data(args.dir, test_dir, image_filter,
|
| 146 |
+
args.mask_filter,
|
| 147 |
+
args.look_one_level_down)
|
| 148 |
+
images, labels, image_names, test_images, test_labels, image_names_test = output
|
| 149 |
+
load_files = True
|
| 150 |
+
|
| 151 |
+
# initialize model
|
| 152 |
+
model = models.CellposeModel(device=device, pretrained_model=pretrained_model)
|
| 153 |
+
|
| 154 |
+
# train segmentation model
|
| 155 |
+
cpmodel_path = train.train_seg(
|
| 156 |
+
model.net, images, labels, train_files=image_names,
|
| 157 |
+
test_data=test_images, test_labels=test_labels,
|
| 158 |
+
test_files=image_names_test, train_probs=train_probs,
|
| 159 |
+
test_probs=test_probs, compute_flows=compute_flows,
|
| 160 |
+
load_files=load_files, normalize=normalize,
|
| 161 |
+
channel_axis=args.channel_axis,
|
| 162 |
+
learning_rate=args.learning_rate, weight_decay=args.weight_decay,
|
| 163 |
+
SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.train_batch_size,
|
| 164 |
+
min_train_masks=args.min_train_masks,
|
| 165 |
+
nimg_per_epoch=args.nimg_per_epoch,
|
| 166 |
+
nimg_test_per_epoch=args.nimg_test_per_epoch,
|
| 167 |
+
save_path=os.path.realpath(args.dir),
|
| 168 |
+
save_every=args.save_every,
|
| 169 |
+
save_each=args.save_each,
|
| 170 |
+
model_name=args.model_name_out)[0]
|
| 171 |
+
model.pretrained_model = cpmodel_path
|
| 172 |
+
logger.info(">>>> model trained and saved to %s" % cpmodel_path)
|
| 173 |
+
return model
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, normalize):
|
| 177 |
+
# Check with user if they REALLY mean to run without saving anything
|
| 178 |
+
if not args.train:
|
| 179 |
+
saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt
|
| 180 |
+
|
| 181 |
+
tic = time.time()
|
| 182 |
+
if len(args.dir) > 0:
|
| 183 |
+
image_names = io.get_image_files(
|
| 184 |
+
args.dir, args.mask_filter, imf=imf,
|
| 185 |
+
look_one_level_down=args.look_one_level_down)
|
| 186 |
+
else:
|
| 187 |
+
if os.path.exists(args.image_path):
|
| 188 |
+
image_names = [args.image_path]
|
| 189 |
+
else:
|
| 190 |
+
raise ValueError(f"ERROR: no file found at {args.image_path}")
|
| 191 |
+
nimg = len(image_names)
|
| 192 |
+
|
| 193 |
+
if args.savedir:
|
| 194 |
+
if not os.path.exists(args.savedir):
|
| 195 |
+
raise FileExistsError(f"--savedir {args.savedir} does not exist")
|
| 196 |
+
|
| 197 |
+
logger.info(
|
| 198 |
+
">>>> running cellpose on %d images using all channels" % nimg)
|
| 199 |
+
|
| 200 |
+
# handle built-in model exceptions
|
| 201 |
+
model = models.CellposeModel(device=device, pretrained_model=pretrained_model,)
|
| 202 |
+
|
| 203 |
+
tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)
|
| 204 |
+
|
| 205 |
+
channel_axis = args.channel_axis
|
| 206 |
+
z_axis = args.z_axis
|
| 207 |
+
|
| 208 |
+
for image_name in tqdm(image_names, file=tqdm_out):
|
| 209 |
+
if args.do_3D or args.stitch_threshold > 0.:
|
| 210 |
+
logger.info('loading image as 3D zstack')
|
| 211 |
+
image = io.imread_3D(image_name)
|
| 212 |
+
if channel_axis is None:
|
| 213 |
+
channel_axis = 3
|
| 214 |
+
if z_axis is None:
|
| 215 |
+
z_axis = 0
|
| 216 |
+
|
| 217 |
+
else:
|
| 218 |
+
image = io.imread_2D(image_name)
|
| 219 |
+
out = model.eval(
|
| 220 |
+
image,
|
| 221 |
+
diameter=args.diameter,
|
| 222 |
+
do_3D=args.do_3D,
|
| 223 |
+
augment=args.augment,
|
| 224 |
+
flow_threshold=args.flow_threshold,
|
| 225 |
+
cellprob_threshold=args.cellprob_threshold,
|
| 226 |
+
stitch_threshold=args.stitch_threshold,
|
| 227 |
+
min_size=args.min_size,
|
| 228 |
+
batch_size=args.batch_size,
|
| 229 |
+
bsize=args.bsize,
|
| 230 |
+
resample=not args.no_resample,
|
| 231 |
+
normalize=normalize,
|
| 232 |
+
channel_axis=channel_axis,
|
| 233 |
+
z_axis=z_axis,
|
| 234 |
+
anisotropy=args.anisotropy,
|
| 235 |
+
niter=args.niter,
|
| 236 |
+
flow3D_smooth=args.flow3D_smooth)
|
| 237 |
+
masks, flows = out[:2]
|
| 238 |
+
|
| 239 |
+
if args.exclude_on_edges:
|
| 240 |
+
masks = utils.remove_edge_masks(masks)
|
| 241 |
+
if not args.no_npy:
|
| 242 |
+
io.masks_flows_to_seg(image, masks, flows, image_name,
|
| 243 |
+
imgs_restore=None,
|
| 244 |
+
restore_type=None,
|
| 245 |
+
ratio=1.)
|
| 246 |
+
if saving_something:
|
| 247 |
+
suffix = "_cp_masks"
|
| 248 |
+
if args.output_name is not None:
|
| 249 |
+
# (1) If `savedir` is not defined, then must have a non-zero `suffix`
|
| 250 |
+
if args.savedir is None and len(args.output_name) > 0:
|
| 251 |
+
suffix = args.output_name
|
| 252 |
+
elif args.savedir is not None and not os.path.samefile(args.savedir, args.dir):
|
| 253 |
+
# (2) If `savedir` is defined, and different from `dir` then
|
| 254 |
+
# takes the value passed as a param. (which can be empty string)
|
| 255 |
+
suffix = args.output_name
|
| 256 |
+
|
| 257 |
+
io.save_masks(image, masks, flows, image_name,
|
| 258 |
+
suffix=suffix, png=args.save_png,
|
| 259 |
+
tif=args.save_tif, save_flows=args.save_flows,
|
| 260 |
+
save_outlines=args.save_outlines,
|
| 261 |
+
dir_above=args.dir_above, savedir=args.savedir,
|
| 262 |
+
save_txt=args.save_txt, in_folders=args.in_folders,
|
| 263 |
+
save_mpl=args.save_mpl)
|
| 264 |
+
if args.save_rois:
|
| 265 |
+
io.save_rois(masks, image_name)
|
| 266 |
+
logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))
|
| 267 |
+
|
| 268 |
+
return model
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
if __name__ == "__main__":
|
| 272 |
+
main()
|
models/seg_post_model/cellpose/cli.py
ADDED
|
@@ -0,0 +1,240 @@
|
| 1 |
+
"""
|
| 2 |
+
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_arg_parser():
|
| 9 |
+
""" Parses command line arguments for cellpose main function
|
| 10 |
+
|
| 11 |
+
Note: this function has to be in a separate file to allow autodoc to work for CLI.
|
| 12 |
+
The autodoc_mock_imports in conf.py does not work for sphinx-argparse sometimes,
|
| 13 |
+
see https://github.com/ashb/sphinx-argparse/issues/9#issue-1097057823
|
| 14 |
+
"""
|
| 15 |
+
|
| 16 |
+
parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters")
|
| 17 |
+
|
| 18 |
+
# misc settings
|
| 19 |
+
parser.add_argument("--version", action="store_true",
|
| 20 |
+
help="show cellpose version info")
|
| 21 |
+
parser.add_argument(
|
| 22 |
+
"--verbose", action="store_true",
|
| 23 |
+
help="show information about running and settings and save to log")
|
| 24 |
+
parser.add_argument("--Zstack", action="store_true", help="run GUI in 3D mode")
|
| 25 |
+
|
| 26 |
+
# settings for CPU vs GPU
|
| 27 |
+
hardware_args = parser.add_argument_group("Hardware Arguments")
|
| 28 |
+
hardware_args.add_argument("--use_gpu", action="store_true",
|
| 29 |
+
help="use gpu if torch with cuda installed")
|
| 30 |
+
hardware_args.add_argument(
|
| 31 |
+
"--gpu_device", required=False, default="0", type=str,
|
| 32 |
+
help="which gpu device to use, use an integer for torch, or mps for M1")
|
| 33 |
+
|
| 34 |
+
# settings for locating and formatting images
|
| 35 |
+
input_img_args = parser.add_argument_group("Input Image Arguments")
|
| 36 |
+
input_img_args.add_argument("--dir", default=[], type=str,
|
| 37 |
+
help="folder containing data to run or train on.")
|
| 38 |
+
input_img_args.add_argument(
|
| 39 |
+
"--image_path", default=[], type=str, help=
|
| 40 |
+
"if given and --dir not given, run on single image instead of folder (cannot train with this option)"
|
| 41 |
+
)
|
| 42 |
+
input_img_args.add_argument(
|
| 43 |
+
"--look_one_level_down", action="store_true",
|
| 44 |
+
help="run processing on all subdirectories of current folder")
|
| 45 |
+
input_img_args.add_argument("--img_filter", default=[], type=str,
|
| 46 |
+
help="end string for images to run on")
|
| 47 |
+
input_img_args.add_argument(
|
| 48 |
+
"--channel_axis", default=None, type=int,
|
| 49 |
+
help="axis of image which corresponds to image channels")
|
| 50 |
+
input_img_args.add_argument("--z_axis", default=None, type=int,
|
| 51 |
+
help="axis of image which corresponds to Z dimension")
|
| 52 |
+
|
| 53 |
+
# TODO: remove deprecated in future version
|
| 54 |
+
input_img_args.add_argument(
|
| 55 |
+
"--chan", default=0, type=int, help=
|
| 56 |
+
"Deprecated in v4.0.1+, not used. ")
|
| 57 |
+
input_img_args.add_argument(
|
| 58 |
+
"--chan2", default=0, type=int, help=
|
| 59 |
+
'Deprecated in v4.0.1+, not used. ')
|
| 60 |
+
input_img_args.add_argument("--invert", action="store_true", help=
|
| 61 |
+
'Deprecated in v4.0.1+, not used. ')
|
| 62 |
+
input_img_args.add_argument(
|
| 63 |
+
"--all_channels", action="store_true", help=
|
| 64 |
+
'Deprecated in v4.0.1+, not used. ')
|
| 65 |
+
|
| 66 |
+
# model settings
|
| 67 |
+
model_args = parser.add_argument_group("Model Arguments")
|
| 68 |
+
model_args.add_argument("--pretrained_model", required=False, default="cpsam",
|
| 69 |
+
type=str,
|
| 70 |
+
help="model to use for running or starting training")
|
| 71 |
+
model_args.add_argument(
|
| 72 |
+
"--add_model", required=False, default=None, type=str,
|
| 73 |
+
help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
|
| 74 |
+
model_args.add_argument("--pretrained_model_ortho", required=False, default=None,
|
| 75 |
+
type=str,
|
| 76 |
+
help="Deprecated in v4.0.1+, not used. ")
|
| 77 |
+
|
| 78 |
+
# TODO: remove deprecated in future version
|
| 79 |
+
model_args.add_argument("--restore_type", required=False, default=None, type=str, help=
|
| 80 |
+
'Deprecated in v4.0.1+, not used. ')
|
| 81 |
+
model_args.add_argument("--chan2_restore", action="store_true", help=
|
| 82 |
+
'Deprecated in v4.0.1+, not used. ')
|
| 83 |
+
model_args.add_argument(
|
| 84 |
+
"--transformer", action="store_true", help=
|
| 85 |
+
"use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)")
|
| 86 |
+
|
| 87 |
+
# algorithm settings
|
| 88 |
+
algorithm_args = parser.add_argument_group("Algorithm Arguments")
|
| 89 |
+
algorithm_args.add_argument("--no_norm", action="store_true",
|
| 90 |
+
help="do not normalize images (normalize=False)")
|
| 91 |
+
algorithm_args.add_argument(
|
| 92 |
+
'--norm_percentile',
|
| 93 |
+
nargs=2, # Require exactly two values
|
| 94 |
+
metavar=('VALUE1', 'VALUE2'),
|
| 95 |
+
help="Provide two float values to set norm_percentile (e.g., --norm_percentile 1 99)"
|
| 96 |
+
)
|
| 97 |
+
algorithm_args.add_argument(
|
| 98 |
+
"--do_3D", action="store_true",
|
| 99 |
+
help="process images as 3D stacks of images (nplanes x nchan x Ly x Lx")
|
| 100 |
+
algorithm_args.add_argument(
|
| 101 |
+
"--diameter", required=False, default=None, type=float, help=
|
| 102 |
+
"use to resize cells to the training diameter (30 pixels)"
|
| 103 |
+
)
|
| 104 |
+
algorithm_args.add_argument(
|
| 105 |
+
"--stitch_threshold", required=False, default=0.0, type=float,
|
| 106 |
+
help="compute masks in 2D then stitch together masks with IoU>0.9 across planes"
|
| 107 |
+
)
|
| 108 |
+
algorithm_args.add_argument(
|
| 109 |
+
"--min_size", required=False, default=15, type=int,
|
| 110 |
+
help="minimum number of pixels per mask, can turn off with -1")
|
| 111 |
+
algorithm_args.add_argument(
|
| 112 |
+
"--flow3D_smooth", required=False, default=0, type=float,
|
| 113 |
+
help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing")
|
| 114 |
+
algorithm_args.add_argument(
|
| 115 |
+
"--flow_threshold", default=0.4, type=float, help=
|
| 116 |
+
"flow error threshold, 0 turns off this optional QC step. Default: %(default)s")
|
| 117 |
+
algorithm_args.add_argument(
|
| 118 |
+
"--cellprob_threshold", default=0, type=float,
|
| 119 |
+
help="cellprob threshold, default is 0, decrease to find more and larger masks")
|
| 120 |
+
algorithm_args.add_argument(
|
| 121 |
+
"--niter", default=0, type=int, help=
|
| 122 |
+
"niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs"
|
| 123 |
+
)
|
| 124 |
+
algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
|
| 125 |
+
help="anisotropy of volume in 3D")
|
| 126 |
+
algorithm_args.add_argument("--exclude_on_edges", action="store_true",
|
| 127 |
+
help="discard masks which touch edges of image")
|
| 128 |
+
algorithm_args.add_argument(
|
| 129 |
+
"--augment", action="store_true",
|
| 130 |
+
help="tiles image with overlapping tiles and flips overlapped regions to augment"
|
| 131 |
+
)
|
| 132 |
+
algorithm_args.add_argument("--batch_size", default=8, type=int,
|
| 133 |
+
help="inference batch size. Default: %(default)s")
|
| 134 |
+
|
| 135 |
+
# TODO: remove deprecated in future version
|
| 136 |
+
algorithm_args.add_argument(
|
| 137 |
+
"--no_resample", action="store_true",
|
| 138 |
+
help="disables flows/cellprob resampling to original image size before computing masks. Using this flag will make more masks more jagged with larger diameter settings.")
|
| 139 |
+
algorithm_args.add_argument(
|
| 140 |
+
"--no_interp", action="store_true",
|
| 141 |
+
help="do not interpolate when running dynamics (was default)")
|
| 142 |
+
|
| 143 |
+
# output settings
|
| 144 |
+
output_args = parser.add_argument_group("Output Arguments")
|
| 145 |
+
output_args.add_argument(
|
| 146 |
+
"--save_png", action="store_true",
|
| 147 |
+
help="save masks as png")
|
| 148 |
+
output_args.add_argument(
|
| 149 |
+
"--save_tif", action="store_true",
|
| 150 |
+
help="save masks as tif")
|
| 151 |
+
output_args.add_argument(
|
| 152 |
+
"--output_name", default=None, type=str,
|
| 153 |
+
help="suffix for saved masks, default is _cp_masks, can be empty if `savedir` used and different of `dir`")
|
| 154 |
+
output_args.add_argument("--no_npy", action="store_true",
|
| 155 |
+
help="suppress saving of npy")
|
| 156 |
+
output_args.add_argument(
|
| 157 |
+
"--savedir", default=None, type=str, help=
|
| 158 |
+
"folder to which segmentation results will be saved (defaults to input image directory)"
|
| 159 |
+
)
|
| 160 |
+
output_args.add_argument(
|
| 161 |
+
"--dir_above", action="store_true", help=
|
| 162 |
+
"save output folders adjacent to image folder instead of inside it (off by default)"
|
| 163 |
+
)
|
| 164 |
+
output_args.add_argument("--in_folders", action="store_true",
|
| 165 |
+
help="flag to save output in folders (off by default)")
|
| 166 |
+
output_args.add_argument(
|
| 167 |
+
"--save_flows", action="store_true", help=
|
| 168 |
+
"whether or not to save RGB images of flows when masks are saved (disabled by default)"
|
| 169 |
+
)
|
| 170 |
+
output_args.add_argument(
|
| 171 |
+
"--save_outlines", action="store_true", help=
|
| 172 |
+
"whether or not to save RGB outline images when masks are saved (disabled by default)"
|
| 173 |
+
)
|
| 174 |
+
output_args.add_argument(
|
| 175 |
+
"--save_rois", action="store_true",
|
| 176 |
+
help="whether or not to save ImageJ compatible ROI archive (disabled by default)"
|
| 177 |
+
)
|
| 178 |
+
output_args.add_argument(
|
| 179 |
+
"--save_txt", action="store_true",
|
| 180 |
+
help="flag to enable txt outlines for ImageJ (disabled by default)")
|
| 181 |
+
output_args.add_argument(
|
| 182 |
+
"--save_mpl", action="store_true",
|
| 183 |
+
help="save a figure of image/mask/flows using matplotlib (disabled by default). "
|
| 184 |
+
"This is slow, especially with large images.")
|
| 185 |
+
|
| 186 |
+
# training settings
|
| 187 |
+
training_args = parser.add_argument_group("Training Arguments")
|
| 188 |
+
training_args.add_argument("--train", action="store_true",
|
| 189 |
+
help="train network using images in dir")
|
| 190 |
+
training_args.add_argument("--test_dir", default=[], type=str,
|
| 191 |
+
help="folder containing test data (optional)")
|
| 192 |
+
training_args.add_argument(
|
| 193 |
+
"--file_list", default=[], type=str, help=
|
| 194 |
+
"path to list of files for training and testing and probabilities for each image (optional)"
|
| 195 |
+
)
|
| 196 |
+
training_args.add_argument(
|
| 197 |
+
"--mask_filter", default="_masks", type=str, help=
|
| 198 |
+
"end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s"
|
| 199 |
+
)
|
| 200 |
+
training_args.add_argument("--learning_rate", default=1e-5, type=float,
|
| 201 |
+
help="learning rate. Default: %(default)s")
|
| 202 |
+
training_args.add_argument("--weight_decay", default=0.1, type=float,
|
| 203 |
+
help="weight decay. Default: %(default)s")
|
| 204 |
+
training_args.add_argument("--n_epochs", default=100, type=int,
|
| 205 |
+
help="number of epochs. Default: %(default)s")
|
| 206 |
+
training_args.add_argument("--train_batch_size", default=1, type=int,
|
| 207 |
+
help="training batch size. Default: %(default)s")
|
| 208 |
+
training_args.add_argument("--bsize", default=256, type=int,
|
| 209 |
+
help="block size for tiles. Default: %(default)s")
|
| 210 |
+
training_args.add_argument(
|
| 211 |
+
"--nimg_per_epoch", default=None, type=int,
|
| 212 |
+
help="number of train images per epoch. Default is to use all train images.")
|
| 213 |
+
training_args.add_argument(
|
| 214 |
+
"--nimg_test_per_epoch", default=None, type=int,
|
| 215 |
+
help="number of test images per epoch. Default is to use all test images.")
|
| 216 |
+
training_args.add_argument(
|
| 217 |
+
"--min_train_masks", default=5, type=int, help=
|
| 218 |
+
"minimum number of masks a training image must have to be used. Default: %(default)s"
|
| 219 |
+
)
|
| 220 |
+
training_args.add_argument("--SGD", default=0, type=int,
|
| 221 |
+
help="Deprecated in v4.0.1+, not used - AdamW used instead. ")
|
| 222 |
+
training_args.add_argument(
|
| 223 |
+
"--save_every", default=100, type=int,
|
| 224 |
+
help="number of epochs to skip between saves. Default: %(default)s")
|
| 225 |
+
training_args.add_argument(
|
| 226 |
+
"--save_each", action="store_true",
|
| 227 |
+
help="wether or not to save each epoch. Must also use --save_every. (default: False)")
|
| 228 |
+
training_args.add_argument(
|
| 229 |
+
"--model_name_out", default=None, type=str,
|
| 230 |
+
help="Name of model to save as, defaults to name describing model architecture. "
|
| 231 |
+
"Model is saved in the folder specified by --dir in models subfolder.")
|
| 232 |
+
|
| 233 |
+
# TODO: remove deprecated in future version
|
| 234 |
+
training_args.add_argument(
|
| 235 |
+
"--diam_mean", default=30., type=float, help=
|
| 236 |
+
'Deprecated in v4.0.1+, not used. ')
|
| 237 |
+
training_args.add_argument("--train_size", action="store_true", help=
|
| 238 |
+
'Deprecated in v4.0.1+, not used. ')
|
| 239 |
+
|
| 240 |
+
return parser
|
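For a quick sanity check outside the shell, the parser can also be driven programmatically; the flags below are the ones defined above, with illustrative values:

```python
# Programmatic use of the parser defined above (argument values are illustrative):
parser = get_arg_parser()
args = parser.parse_args(["--dir", "/path/to/images", "--save_png",
                          "--verbose", "--diameter", "30"])
print(args.dir, args.diameter, args.save_png)
```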
models/seg_post_model/cellpose/core.py
ADDED
|
@@ -0,0 +1,322 @@
|
| 1 |
+
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import trange
|
| 7 |
+
from . import transforms, utils
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
TORCH_ENABLED = True
|
| 12 |
+
|
| 13 |
+
core_logger = logging.getLogger(__name__)
|
| 14 |
+
tqdm_out = utils.TqdmToLogger(core_logger, level=logging.INFO)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def use_gpu(gpu_number=0, use_torch=True):
|
| 18 |
+
"""
|
| 19 |
+
Check if GPU is available for use.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
gpu_number (int): The index of the GPU to be used. Default is 0.
|
| 23 |
+
use_torch (bool): Whether to use PyTorch for GPU check. Default is True.
|
| 24 |
+
|
| 25 |
+
Returns:
|
| 26 |
+
bool: True if GPU is available, False otherwise.
|
| 27 |
+
|
| 28 |
+
Raises:
|
| 29 |
+
ValueError: If use_torch is False, as cellpose only runs with PyTorch now.
|
| 30 |
+
"""
|
| 31 |
+
if use_torch:
|
| 32 |
+
return _use_gpu_torch(gpu_number)
|
| 33 |
+
else:
|
| 34 |
+
raise ValueError("cellpose only runs with PyTorch now")
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _use_gpu_torch(gpu_number=0):
|
| 38 |
+
"""
|
| 39 |
+
Checks if CUDA or MPS is available and working with PyTorch.
|
| 40 |
+
|
| 41 |
+
Args:
|
| 42 |
+
gpu_number (int): The GPU device number to use (default is 0).
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
bool: True if CUDA or MPS is available and working, False otherwise.
|
| 46 |
+
"""
|
| 47 |
+
try:
|
| 48 |
+
device = torch.device("cuda:" + str(gpu_number))
|
| 49 |
+
_ = torch.zeros((1,1)).to(device)
|
| 50 |
+
core_logger.info("** TORCH CUDA version installed and working. **")
|
| 51 |
+
return True
|
| 52 |
+
except Exception:
|
| 53 |
+
pass
|
| 54 |
+
try:
|
| 55 |
+
device = torch.device('mps:' + str(gpu_number))
|
| 56 |
+
_ = torch.zeros((1,1)).to(device)
|
| 57 |
+
core_logger.info('** TORCH MPS version installed and working. **')
|
| 58 |
+
return True
|
| 59 |
+
except Exception:
|
| 60 |
+
core_logger.info('Neither TORCH CUDA nor MPS is installed/working.')
|
| 61 |
+
return False
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def assign_device(use_torch=True, gpu=False, device=0):
|
| 65 |
+
"""
|
| 66 |
+
Assigns the device (CPU or GPU or mps) to be used for computation.
|
| 67 |
+
|
| 68 |
+
Args:
|
| 69 |
+
use_torch (bool, optional): Whether to use torch for GPU detection. Defaults to True.
|
| 70 |
+
gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
|
| 71 |
+
device (int or str, optional): The device index or name to be used. Defaults to 0.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
torch.device, bool (True if GPU is used, False otherwise)
|
| 75 |
+
"""
|
| 76 |
+
|
| 77 |
+
if isinstance(device, str):
|
| 78 |
+
if device != "mps" or not(gpu and torch.backends.mps.is_available()):
|
| 79 |
+
device = int(device)
|
| 80 |
+
if gpu and use_gpu(use_torch=True):
|
| 81 |
+
try:
|
| 82 |
+
if torch.cuda.is_available():
|
| 83 |
+
device = torch.device(f'cuda:{device}')
|
| 84 |
+
core_logger.info(">>>> using GPU (CUDA)")
|
| 85 |
+
gpu = True
|
| 86 |
+
cpu = False
|
| 87 |
+
except Exception:
|
| 88 |
+
gpu = False
|
| 89 |
+
cpu = True
|
| 90 |
+
try:
|
| 91 |
+
if torch.backends.mps.is_available():
|
| 92 |
+
device = torch.device('mps')
|
| 93 |
+
core_logger.info(">>>> using GPU (MPS)")
|
| 94 |
+
gpu = True
|
| 95 |
+
cpu = False
|
| 96 |
+
except Exception:
|
| 97 |
+
gpu = False
|
| 98 |
+
cpu = True
|
| 99 |
+
else:
|
| 100 |
+
device = torch.device('cpu')
|
| 101 |
+
core_logger.info('>>>> using CPU')
|
| 102 |
+
gpu = False
|
| 103 |
+
cpu = True
|
| 104 |
+
|
| 105 |
+
if cpu:
|
| 106 |
+
device = torch.device("cpu")
|
| 107 |
+
core_logger.info(">>>> using CPU")
|
| 108 |
+
gpu = False
|
| 109 |
+
return device, gpu
|
| 110 |
+
|
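`assign_device` is what the CLI calls to pick a backend, falling back to CPU when neither CUDA nor MPS is usable; typical use (the printed output is illustrative):

```python
# Typical use of assign_device (mirrors the call in __main__.py):
device, gpu = assign_device(use_torch=True, gpu=True, device=0)
print(device, gpu)   # e.g. cuda:0 True on a CUDA machine, cpu False otherwise
```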
| 111 |
+
|
| 112 |
+
def _to_device(x, device, dtype=torch.float32):
|
| 113 |
+
"""
|
| 114 |
+
Converts the input tensor or numpy array to the specified device.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
x (torch.Tensor or numpy.ndarray): The input tensor or numpy array.
|
| 118 |
+
device (torch.device): The target device.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
torch.Tensor: The converted tensor on the specified device.
|
| 122 |
+
"""
|
| 123 |
+
if not isinstance(x, torch.Tensor):
|
| 124 |
+
X = torch.from_numpy(x).to(device, dtype=dtype)
|
| 125 |
+
return X
|
| 126 |
+
else:
|
| 127 |
+
return x
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _from_device(X):
|
| 131 |
+
"""
|
| 132 |
+
Converts a PyTorch tensor from the device to a NumPy array on the CPU.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
X (torch.Tensor): The input PyTorch tensor.
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
numpy.ndarray: The converted NumPy array.
|
| 139 |
+
"""
|
| 140 |
+
# The cast is so numpy conversion always works
|
| 141 |
+
x = X.detach().cpu().to(torch.float32).numpy()
|
| 142 |
+
return x
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _forward(net, x, feat=None):
    """Converts images to torch tensors, runs the network model, and returns numpy arrays.

    Args:
        net (torch.nn.Module): The network model.
        x (numpy.ndarray): The input images.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features.
    """
    X = _to_device(x, device=net.device, dtype=net.dtype)
    if feat is not None:
        feat = _to_device(feat, device=net.device, dtype=net.dtype)
    net.eval()
    with torch.no_grad():
        y, style = net(X, feat=feat)[:2]
    del X
    y = _from_device(y)
    style = _from_device(style)
    return y, style


def run_net(net, imgi, feat=None, batch_size=8, augment=False, tile_overlap=0.1, bsize=224,
            rsz=None):
    """
    Run network on stack of images.

    (faster if augment is False)

    Args:
        net (class): cellpose network (model.net)
        imgi (np.ndarray): The input image or stack of images of size [Lz x Ly x Lx x nchan].
        batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
        rsz (float, optional): Resize coefficient(s) for the image. Defaults to None (no resizing).
        augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
        tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
        bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled, `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3];
            y[..., 0] is Y flow; y[..., 1] is X flow; y[..., 2] is cell probability.
            style is a 1D array of size 256 summarizing the style of the image; if tiled, `style` is averaged over tiles.
    """
    # run network
    Lz, Ly0, Lx0, nchan = imgi.shape
    if rsz is not None:
        if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
            rsz = [rsz, rsz]
        Lyr, Lxr = int(Ly0 * rsz[0]), int(Lx0 * rsz[1])
    else:
        Lyr, Lxr = Ly0, Lx0

    ly, lx = bsize, bsize
    # pad so image dimensions are at least bsize and divisible for the net
    ypad1, ypad2, xpad1, xpad2 = transforms.get_pad_yx(Lyr, Lxr, min_size=(bsize, bsize))
    Ly, Lx = Lyr + ypad1 + ypad2, Lxr + xpad1 + xpad2
    pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])

    if augment:
        ny = max(2, int(np.ceil(2. * Ly / bsize)))
        nx = max(2, int(np.ceil(2. * Lx / bsize)))
    else:
        ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
        nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))

    # run multiple slices at the same time
    ntiles = ny * nx
    nimgs = max(1, batch_size // ntiles)  # number of images to run in the same batch
    niter = int(np.ceil(Lz / nimgs))
    ziterator = (trange(niter, file=tqdm_out, mininterval=30)
                 if niter > 10 or Lz > 1 else range(niter))
    for k in ziterator:
        inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs))
        IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")
        if feat is not None:
            FEATa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")
        else:
            FEATa = None
        for i, b in enumerate(inds):
            # pad image for net so Ly and Lx are divisible by 4
            imgb = transforms.resize_image(imgi[b], rsz=rsz) if rsz is not None else imgi[b].copy()
            imgb = np.pad(imgb.transpose(2, 0, 1), pads, mode="constant")

            IMG, ysub, xsub, Lyt, Lxt = transforms.make_tiles(
                imgb, bsize=bsize, augment=augment, tile_overlap=tile_overlap)
            IMGa[i * ntiles : (i + 1) * ntiles] = np.reshape(IMG, (ny * nx, nchan, ly, lx))
            if feat is not None:
                featb = transforms.resize_image(feat[b], rsz=rsz) if rsz is not None else feat[b].copy()
                featb = np.pad(featb.transpose(2, 0, 1), pads, mode="constant")
                FEAT, ysub, xsub, Lyt, Lxt = transforms.make_tiles(
                    featb, bsize=bsize, augment=augment, tile_overlap=tile_overlap)
                FEATa[i * ntiles : (i + 1) * ntiles] = np.reshape(FEAT, (ny * nx, nchan, ly, lx))

        # run network
        for j in range(0, IMGa.shape[0], batch_size):
            bslc = slice(j, min(j + batch_size, IMGa.shape[0]))
            ya0, stylea0 = _forward(net, IMGa[bslc], feat=FEATa[bslc] if FEATa is not None else None)
            if j == 0:
                nout = ya0.shape[1]
                ya = np.zeros((IMGa.shape[0], nout, ly, lx), "float32")
                stylea = np.zeros((IMGa.shape[0], 256), "float32")
            ya[bslc] = ya0
            stylea[bslc] = stylea0

        # average tiles
        for i, b in enumerate(inds):
            if i == 0 and k == 0:
                yf = np.zeros((Lz, nout, Ly, Lx), "float32")
                styles = np.zeros((Lz, 256), "float32")
            y = ya[i * ntiles : (i + 1) * ntiles]
            if augment:
                y = np.reshape(y, (ny, nx, 3, ly, lx))
                y = transforms.unaugment_tiles(y)
                y = np.reshape(y, (-1, 3, ly, lx))
            yfi = transforms.average_tiles(y, ysub, xsub, Lyt, Lxt)
            yf[b] = yfi[:, :imgb.shape[-2], :imgb.shape[-1]]
            stylei = stylea[i * ntiles : (i + 1) * ntiles].sum(axis=0)
            stylei /= (stylei**2).sum()**0.5
            styles[b] = stylei
    # remove the padding added for tiling
    yf = yf[:, :, ypad1 : Ly - ypad2, xpad1 : Lx - xpad2]
    yf = yf.transpose(0, 2, 3, 1)
    return yf, np.array(styles)


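# Illustrative sketch of calling run_net on a small stack (array values and the
# loaded `net` are assumptions, not from the docs); imgi must be
# [Lz x Ly x Lx x nchan] with nchan matching the network.
# >>> imgi = np.random.rand(1, 512, 512, 3).astype("float32")
# >>> y, style = run_net(net, imgi, batch_size=8, bsize=224)
# >>> y.shape, style.shape
# ((1, 512, 512, 3), (1, 256))

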
def run_3D(net, imgs, batch_size=8, augment=False,
           tile_overlap=0.1, bsize=224, net_ortho=None,
           progress=None):
    """
    Run network on image z-stack.

    (faster if augment is False)

    Args:
        imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan].
        batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
        augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
        tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
        bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
        net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes. Defaults to None.
        progress (QProgressBar, optional): pyqt progress bar. Defaults to None.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style; `y` has size [Lz x Ly x Lx x 4];
            y[..., 0] is Z flow; y[..., 1] is Y flow; y[..., 2] is X flow; y[..., 3] is cell probability.
            style is a 1D array of size 256 summarizing the style of the image.
    """
    sstr = ["YX", "ZY", "ZX"]
    pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]
    ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]
    cp = [(1, 2), (0, 2), (0, 1)]
    cpy = [(0, 1), (0, 1), (0, 1)]
    shape = imgs.shape[:-1]
    yf = np.zeros((*shape, 4), "float32")
    for p in range(3):
        xsl = imgs.transpose(pm[p])
        # per image
        core_logger.info("running %s: %d planes of size (%d, %d)" %
                         (sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]]))
        y, style = run_net(net, xsl, batch_size=batch_size, augment=augment,
                           bsize=bsize, tile_overlap=tile_overlap, rsz=None)
        yf[..., -1] += y[..., -1].transpose(ipm[p])
        for j in range(2):
            yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])
        del y

        if progress is not None:
            progress.setValue(25 + 15 * p)

    return yf, style
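
# The three passes above accumulate flow estimates from orthogonal slicings:
# the YX pass contributes Y/X flows, the ZY pass Z/Y flows, and the ZX pass
# Z/X flows, so each flow channel receives two estimates and the cell
# probability receives three. Illustrative call (the loaded `net` is assumed):
# >>> yf, style = run_3D(net, imgs)   # imgs: [Lz x Ly x Lx x nchan]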
models/seg_post_model/cellpose/denoise.py
ADDED
@@ -0,0 +1,1474 @@
"""
Copyright © 2025 Howard Hughes Medical Institute. Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""
import os, time, datetime
import numpy as np
from scipy.stats import mode
import cv2
import torch
from torch import nn
from torch.nn.functional import conv2d, interpolate
from tqdm import trange
from pathlib import Path

import logging

denoise_logger = logging.getLogger(__name__)

from cellpose import transforms, utils, io
from cellpose.core import run_net
from cellpose.models import CellposeModel, model_path, normalize_default, assign_device
# CPnet and resnet_torch are used below (DenoiseModel, one_chan_cellpose) but
# were missing from the original imports
from cellpose import resnet_torch
from cellpose.resnet_torch import CPnet

MODEL_NAMES = []
for ctype in ["cyto3", "cyto2", "nuclei"]:
    for ntype in ["denoise", "deblur", "upsample", "oneclick"]:
        MODEL_NAMES.append(f"{ntype}_{ctype}")
        if ctype != "cyto3":
            for ltype in ["per", "seg", "rec"]:
                MODEL_NAMES.append(f"{ntype}_{ltype}_{ctype}")
    if ctype != "cyto3":
        MODEL_NAMES.append(f"aniso_{ctype}")

criterion = nn.MSELoss(reduction="mean")
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")


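# For reference, the MODEL_NAMES loop above expands to names like
# "denoise_cyto3", "deblur_cyto3", "upsample_cyto3", "oneclick_cyto3" for
# cyto3, and for cyto2/nuclei additionally the per/seg/rec loss variants
# ("denoise_per_cyto2", "denoise_seg_cyto2", "denoise_rec_cyto2", ...) plus
# "aniso_cyto2" and "aniso_nuclei".

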
def deterministic(seed=0):
    """ set random seeds to create test data """
    import random
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU
    np.random.seed(seed)  # Numpy module
    random.seed(seed)  # Python random module
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def loss_fn_rec(lbl, y):
    """ loss function between true labels lbl and prediction y """
    loss = 80. * criterion(y, lbl)
    return loss


def loss_fn_seg(lbl, y):
    """ loss function between true labels lbl and prediction y """
    veci = 5. * lbl[:, 1:]
    lbl = (lbl[:, 0] > .5).float()
    loss = criterion(y[:, :2], veci)
    loss /= 2.
    loss2 = criterion2(y[:, 2], lbl)
    loss = loss + loss2
    return loss


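# In symbols, loss_fn_seg above composes two terms: with predicted flows
# f = y[:, :2], scaled true flows v = 5 * lbl[:, 1:], and binary cell mask
# m = (lbl[:, 0] > 0.5),
#   L_seg = MSE(f, v) / 2 + BCEWithLogits(y[:, 2], m)

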
def get_sigma(Tdown):
    """ Calculates the correlation matrices across channels for the perceptual loss.

    Args:
        Tdown (list): List of tensors output by each downsampling block of the network.

    Returns:
        list: List of correlations for each input tensor.
    """
    Tnorm = [x - x.mean((-2, -1), keepdim=True) for x in Tdown]
    Tnorm = [x / x.std((-2, -1), keepdim=True) for x in Tnorm]
    Sigma = [
        torch.einsum("bnxy, bmxy -> bnm", x, x) / (x.shape[-2] * x.shape[-1])
        for x in Tnorm
    ]
    return Sigma


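# Sketch of shapes (illustrative tensors, not from the library): for a feature
# map of size (batch, nchan, Ly, Lx), each entry of Sigma is a
# (batch, nchan, nchan) channel-correlation matrix.
# >>> feats = [torch.randn(2, 32, 56, 56), torch.randn(2, 64, 28, 28)]
# >>> [s.shape for s in get_sigma(feats)]
# [torch.Size([2, 32, 32]), torch.Size([2, 64, 64])]

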
def imstats(X, net1):
    """
    Calculates the image correlation matrices for the perceptual loss.

    Args:
        X (torch.Tensor): Input image tensor.
        net1: Cellpose net.

    Returns:
        list: A list of tensors of correlation matrices.
    """
    _, _, Tdown = net1(X)
    Sigma = get_sigma(Tdown)
    Sigma = [x.detach() for x in Sigma]
    return Sigma


def loss_fn_per(img, net1, yl):
    """
    Calculates the perceptual loss function for image restoration.

    Args:
        img (torch.Tensor): Input image tensor (noisy/blurry/downsampled).
        net1 (torch.nn.Module): Perceptual loss net (Cellpose segmentation net).
        yl (torch.Tensor): Clean image tensor.

    Returns:
        torch.Tensor: Mean perceptual loss.
    """
    Sigma = imstats(img, net1)
    sd = [x.std((1, 2)) + 1e-6 for x in Sigma]
    Sigma_test = get_sigma(yl)
    losses = torch.zeros(len(Sigma[0]), device=img.device)
    for k in range(len(Sigma)):
        losses = losses + (((Sigma_test[k] - Sigma[k])**2).mean((1, 2)) / sd[k]**2)
    return losses.mean()


def test_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
    """
    Calculates the test loss for image restoration tasks.

    Args:
        net0 (torch.nn.Module): The image restoration network.
        X (torch.Tensor): The input image tensor.
        net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
        img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
        lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
        lam (list, optional): The weights for the loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].

    Returns:
        tuple: A tuple containing the total loss and the perceptual loss.
    """
    net0.eval()
    if net1 is not None:
        net1.eval()
    loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)

    with torch.no_grad():
        img_dn = net0(X)[0]
        if lam[2] > 0.:
            loss += lam[2] * loss_fn_rec(img, img_dn)
        if lam[1] > 0. or lam[0] > 0.:
            y, _, ydown = net1(img_dn)
            if lam[1] > 0.:
                loss += lam[1] * loss_fn_seg(lbl, y)
            if lam[0] > 0.:
                loss_per = loss_fn_per(img, net1, ydown)
                loss += lam[0] * loss_per
    return loss, loss_per


def train_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
    """
    Calculates the training loss for image restoration tasks.

    Args:
        net0 (torch.nn.Module): The image restoration network.
        X (torch.Tensor): The input image tensor.
        net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None.
        img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None.
        lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None.
        lam (list, optional): The weights for the loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.].

    Returns:
        tuple: A tuple containing the total loss and the perceptual loss.
    """
    net0.train()
    if net1 is not None:
        net1.eval()
    loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device)

    img_dn = net0(X)[0]
    if lam[2] > 0.:
        loss += lam[2] * loss_fn_rec(img, img_dn)
    if lam[1] > 0. or lam[0] > 0.:
        y, _, ydown = net1(img_dn)
        if lam[1] > 0.:
            loss += lam[1] * loss_fn_seg(lbl, y)
        if lam[0] > 0.:
            loss_per = loss_fn_per(img, net1, ydown)
            loss += lam[0] * loss_per
    return loss, loss_per


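# Hedged training-step sketch (the optimizer and tensor names below are
# illustrative assumptions, not from this file):
# >>> opt = torch.optim.AdamW(net0.parameters(), lr=1e-4)
# >>> loss, loss_per = train_loss(net0, X, net1=net1, img=img_clean, lbl=flows,
# ...                             lam=[1., 1.5, 0.])
# >>> opt.zero_grad(); loss.backward(); opt.step()

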
def img_norm(imgi):
    """
    Normalizes the input image by subtracting the 1st percentile and dividing by the difference between the 99th and 1st percentiles.

    Args:
        imgi (torch.Tensor): Input image tensor.

    Returns:
        torch.Tensor: Normalized image tensor.
    """
    shape = imgi.shape
    imgi = imgi.reshape(imgi.shape[0], imgi.shape[1], -1)
    perc = torch.quantile(imgi, torch.tensor([0.01, 0.99], device=imgi.device), dim=-1,
                          keepdim=True)
    for k in range(imgi.shape[1]):
        hask = (perc[1, :, k, 0] - perc[0, :, k, 0]) > 1e-3
        imgi[hask, k] -= perc[0, hask, k]
        imgi[hask, k] /= (perc[1, hask, k] - perc[0, hask, k])
    imgi = imgi.reshape(shape)
    return imgi


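# Quick sanity check (illustrative values; note the function normalizes
# in-place, hence the .clone()):
# >>> x = torch.rand(4, 2, 224, 224) * 100.
# >>> xn = img_norm(x.clone())
# >>> xn.min().item(), xn.max().item()   # roughly in [-0.01, 1.01]

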
def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7,
              ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None,
              ds=None, uniform_blur=False, partial_blur=False):
    """Adds noise to the input image.

    Args:
        lbl (torch.Tensor): The input image tensor of shape (nimg, nchan, Ly, Lx).
        alpha (float, optional): The shape parameter of the gamma distribution used for generating poisson noise. Defaults to 4.
        beta (float, optional): The rate parameter of the gamma distribution used for generating poisson noise. Defaults to 0.7.
        poisson (float, optional): The probability of adding poisson noise to the image. Defaults to 0.7.
        blur (float, optional): The probability of adding gaussian blur to the image. Defaults to 0.7.
        gblur (float, optional): The scale factor for the gaussian blur. Defaults to 1.0.
        downsample (float, optional): The probability of downsampling the image. Defaults to 0.7.
        ds_max (int, optional): The maximum downsampling factor. Defaults to 7.
        diams (torch.Tensor, optional): The diameter of the objects in the image. Defaults to None.
        pscale (torch.Tensor, optional): The scale factor for the poisson noise, instead of sampling. Defaults to None.
        iso (bool, optional): Whether to use isotropic gaussian blur. Defaults to True.
        sigma0 (torch.Tensor, optional): The standard deviation of the gaussian filter for the Y axis, instead of sampling. Defaults to None.
        sigma1 (torch.Tensor, optional): The standard deviation of the gaussian filter for the X axis, instead of sampling. Defaults to None.
        ds (torch.Tensor, optional): The downsampling factor for each image, instead of sampling. Defaults to None.

    Returns:
        torch.Tensor: The noisy image tensor of the same shape as the input image.
    """
    device = lbl.device
    imgi = torch.zeros_like(lbl)
    Ly, Lx = lbl.shape[-2:]

    diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device)
    ds = ds * torch.ones(
        (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds

    # downsample
    ii = []
    idownsample = np.random.rand(len(lbl)) < downsample
    if (ds is None and idownsample.sum() > 0.) or not iso:
        ds = torch.ones(len(lbl), dtype=torch.long, device=device)
        ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),),
                                        device=device)
        ii = torch.nonzero(ds > 1).flatten()
    elif ds is not None and (ds > 1).sum():
        ii = torch.nonzero(ds > 1).flatten()

    # add gaussian blur
    iblur = torch.rand(len(lbl), device=device) < blur
    iblur[ii] = True
    if iblur.sum() > 0:
        if sigma0 is None:
            if uniform_blur and iso:
                xr = torch.rand(len(lbl), device=device)
                if len(ii) > 0:
                    xr[ii] = ds[ii].float() / 2. / gblur
                sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur])
                sigma1 = sigma0.clone()
            elif not iso:
                xr = torch.rand(len(lbl), device=device)
                if len(ii) > 0:
                    xr[ii] = (ds[ii].float()) / gblur
                    xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35
                    xr[ii] = torch.clip(xr[ii], 0.05, 1.5)
                sigma0 = diams[iblur] / 30. * gblur * xr[iblur]
                sigma1 = sigma0.clone() / 10.
            else:
                xrand = np.random.exponential(1, size=iblur.sum())
                xrand = np.clip(xrand * 0.5, 0.1, 1.0)
                xrand *= gblur
                sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to(device)
                sigma1 = sigma0.clone()
        else:
            sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device)
            sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device)

        # create gaussian filter
        xr = max(8, sigma0.max().long() * 2)
        gfilt0 = torch.exp(-torch.arange(-xr + 1, xr, device=device)**2 /
                           (2 * sigma0.unsqueeze(-1)**2))
        gfilt0 /= gfilt0.sum(axis=-1, keepdims=True)
        gfilt1 = torch.zeros_like(gfilt0)
        gfilt1[sigma1 == sigma0] = gfilt0[sigma1 == sigma0]
        gfilt1[sigma1 != sigma0] = torch.exp(
            -torch.arange(-xr + 1, xr, device=device)**2 /
            (2 * sigma1[sigma1 != sigma0].unsqueeze(-1)**2))
        gfilt1[sigma1 == 0] = 0.
        gfilt1[sigma1 == 0, xr] = 1.
        gfilt1 /= gfilt1.sum(axis=-1, keepdims=True)
        gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1)
        gfilt /= gfilt.sum(axis=(1, 2), keepdims=True)

        lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1),
                          padding=gfilt.shape[-1] // 2,
                          groups=gfilt.shape[0]).transpose(1, 0)
        if partial_blur:
            imgi[iblur] = lbl[iblur].clone()
            Lxc = int(Lx * 0.85)
            ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32),
                                    torch.arange(0, Lxc, dtype=torch.float32),
                                    indexing="ij")
            mask = torch.exp(-(ym**2 + xm**2) / 2 * (0.001**2))
            mask -= mask.min()
            mask /= mask.max()
            lbl_blur_crop = lbl_blur[:, :, :, :Lxc]
            imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask +
                                       (1 - mask) * imgi[iblur, :, :, :Lxc])
        else:
            imgi[iblur] = lbl_blur

    imgi[~iblur] = lbl[~iblur]

    # apply downsample
    for k in ii:
        i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]]
        imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear")

    # add poisson noise
    ipoisson = np.random.rand(len(lbl)) < poisson
    if ipoisson.sum() > 0:
        if pscale is None:
            m = torch.distributions.gamma.Gamma(alpha, beta)
            pscale = torch.clamp(m.rsample(sample_shape=(ipoisson.sum(),)), 1.)
            pscale = pscale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device)
        else:
            pscale = pscale * torch.ones((ipoisson.sum(), 1, 1, 1), device=device)
        imgi[ipoisson] = torch.poisson(pscale * imgi[ipoisson])

    # renormalize
    imgi = img_norm(imgi)

    return imgi


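# Hedged usage sketch: degrade a clean batch for restoration training. The
# tensor and parameter values below are illustrative; `clean` is any
# (nimg, nchan, Ly, Lx) tensor with non-negative intensities.
# >>> clean = torch.rand(8, 1, 224, 224)
# >>> noisy = add_noise(clean, poisson=1.0, blur=0.0, downsample=0.0)  # Poisson only

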
def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7,
                                   downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30,
                                   ds_max=7, uniform_blur=False, iso=True, rotate=True,
                                   device=torch.device("cuda"), xy=(224, 224),
                                   nchan_noise=1, keep_raw=True):
    """
    Applies random rotation, resizing, and noise to the input data.

    Args:
        data (numpy.ndarray): The input data.
        labels (numpy.ndarray, optional): The flow and cellprob labels associated with the data. Defaults to None.
        diams (float, optional): The diameter of the objects. Defaults to None.
        poisson (float, optional): The Poisson noise probability. Defaults to 0.7.
        blur (float, optional): The blur probability. Defaults to 0.7.
        downsample (float, optional): The downsample probability. Defaults to 0.0.
        beta (float, optional): The beta value for the poisson noise distribution. Defaults to 0.7.
        gblur (float, optional): The Gaussian blur level. Defaults to 1.0.
        diam_mean (float, optional): The mean diameter. Defaults to 30.
        ds_max (int, optional): The maximum downsample value. Defaults to 7.
        iso (bool, optional): Whether to apply isotropic augmentation. Defaults to True.
        rotate (bool, optional): Whether to apply rotation augmentation. Defaults to True.
        device (torch.device, optional): The device to use. Defaults to torch.device("cuda").
        xy (tuple, optional): The size of the output image. Defaults to (224, 224).
        nchan_noise (int, optional): The number of channels to add noise to. Defaults to 1.
        keep_raw (bool, optional): Whether to keep the raw image. Defaults to True.

    Returns:
        torch.Tensor: The augmented image and augmented noisy/blurry/downsampled version of the image.
        torch.Tensor: The augmented labels.
        float: The scale factor applied to the image.
    """
    if device is None:
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device(
            "mps") if torch.backends.mps.is_available() else torch.device("cpu")

    diams = 30 if diams is None else diams
    random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1))
    random_rsc = diams / random_diam
    xy0 = (340, 340)
    nchan = data[0].shape[0]
    data_new = np.zeros((len(data), (1 + keep_raw) * nchan, xy0[0], xy0[1]), "float32")
    labels_new = np.zeros((len(data), 3, xy0[0], xy0[1]), "float32")
    for i in range(len(data)):
        sc = random_rsc[i]
        img = data[i]
        lbl = labels[i] if labels is not None else None
        # create affine transform to resize
        Ly, Lx = img.shape[-2:]
        dxy = np.maximum(0, np.array([Lx / sc - xy0[1], Ly / sc - xy0[0]]))
        dxy = (np.random.rand(2,) - .5) * dxy
        cc = np.array([Lx / 2, Ly / 2])
        cc1 = cc - np.array([Lx - xy0[1], Ly - xy0[0]]) / 2 + dxy
        pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
        pts2 = np.float32(
            [cc1, cc1 + np.array([1, 0]) / sc, cc1 + np.array([0, 1]) / sc])
        M = cv2.getAffineTransform(pts1, pts2)

        # apply to image
        for c in range(nchan):
            img_rsz = cv2.warpAffine(img[c], M, xy0, flags=cv2.INTER_LINEAR)
            data_new[i, c] = img_rsz
            if keep_raw:
                data_new[i, c + nchan] = img_rsz

        if lbl is not None:
            # apply to labels
            labels_new[i, 0] = cv2.warpAffine(lbl[0], M, xy0, flags=cv2.INTER_NEAREST)
            labels_new[i, 1] = cv2.warpAffine(lbl[1], M, xy0, flags=cv2.INTER_LINEAR)
            labels_new[i, 2] = cv2.warpAffine(lbl[2], M, xy0, flags=cv2.INTER_LINEAR)

    rsc = random_diam / diam_mean

    # add noise before augmentations
    img = torch.from_numpy(data_new).to(device)
    img = torch.clamp(img, 0.)
    # only add noise to the cytoplasm channel if nchan_noise=1
    img[:, :nchan_noise] = add_noise(
        img[:, :nchan_noise], poisson=poisson, blur=blur, ds_max=ds_max, iso=iso,
        downsample=downsample, beta=beta, gblur=gblur,
        diams=torch.from_numpy(random_diam).to(device).float())
    img = img.cpu().numpy()

    # augmentations
    img, lbl, scale = transforms.random_rotate_and_resize(
        img,
        Y=labels_new,
        xy=xy,
        rotate=False if not iso else rotate,
        rescale=rsc,
        scale_range=0.5)
    img = torch.from_numpy(img).to(device)
    lbl = torch.from_numpy(lbl).to(device)

    return img, lbl, scale


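# Hedged usage sketch of the training-time augmentation (array shapes are
# illustrative assumptions):
# >>> imgs = [np.random.rand(1, 512, 512).astype("float32") for _ in range(4)]
# >>> flows = [np.random.rand(3, 512, 512).astype("float32") for _ in range(4)]
# >>> img, lbl, scale = random_rotate_and_resize_noise(imgs, labels=flows,
# ...                                                  device=torch.device("cpu"))
# >>> img.shape   # (4, 2, 224, 224): [noisy, raw] channels with keep_raw=True

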
def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
    """
    Creates a Cellpose network with a single input channel.

    Args:
        device (str): The device to run the network on.
        model_type (str, optional): The type of Cellpose model to use. Defaults to "cyto2".
        pretrained_model (str, optional): The path to a pretrained model file. Defaults to None.

    Returns:
        torch.nn.Module: The Cellpose network with a single input channel.
    """
    if pretrained_model is not None and not os.path.exists(pretrained_model):
        model_type = pretrained_model
        pretrained_model = None
    nbase = [32, 64, 128, 256]
    nchan = 1
    net1 = resnet_torch.CPnet([nchan, *nbase], nout=3, sz=3).to(device)
    filename = model_path(model_type,
                          0) if pretrained_model is None else pretrained_model
    weights = torch.load(filename, weights_only=True)
    zp = 0
    print(filename)
    for name in net1.state_dict():
        if ("res_down_0.conv.conv_0" not in name and
                "res_down_0.proj" not in name and name != "diam_mean" and
                name != "diam_labels"):
            net1.state_dict()[name].copy_(weights[name])
        elif "res_down_0" in name:
            # collapse the 2-channel input weights of the pretrained net to 1 channel
            if len(weights[name].shape) > 0:
                new_weight = torch.zeros_like(net1.state_dict()[name])
                if weights[name].shape[0] == 2:
                    new_weight[:] = weights[name][0]
                elif len(weights[name].shape) > 1 and weights[name].shape[1] == 2:
                    new_weight[:, zp] = weights[name][:, 0]
                else:
                    new_weight = weights[name]
            else:
                new_weight = weights[name]
            net1.state_dict()[name].copy_(new_weight)
    return net1


class CellposeDenoiseModel():
    """ model to run Cellpose and image restoration """

    def __init__(self, gpu=False, pretrained_model=False, model_type=None,
                 restore_type="denoise_cyto3", nchan=2,
                 chan2_restore=False, device=None):

        self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore,
                               device=device)
        self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan,
                                pretrained_model=pretrained_model, device=device)

    def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
             normalize=True, rescale=None, diameter=None, tile_overlap=0.1,
             augment=False, resample=True, invert=False, flow_threshold=0.4,
             cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
             min_size=15, niter=None, interp=True, bsize=224, flow3D_smooth=0):
        """
        Restore array or list of images using the image restoration model, and then segment.

        Args:
            x (list, np.ndarray): can be a list of 2D/3D/4D images, or an array of 2D/3D/4D images.
            batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
                (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
            channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
                First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
                Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
                For instance, to segment grayscale images, input [0,0]. To segment images with cells
                in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
                image with cells in green and nuclei in blue, input [[0,0], [2,3]].
                Defaults to None.
            channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
                If None, the channels dimension is attempted to be automatically determined. Defaults to None.
            z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
                If None, the z dimension is attempted to be automatically determined. Defaults to None.
            normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
                can also pass a dictionary of parameters (all keys are optional, default values shown):
                    - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
                    - "sharpen"=0 : sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
                    - "normalize"=True : run normalization (if False, all following parameters ignored)
                    - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
                    - "tile_norm"=0 : compute normalization in tiles across image to brighten dark areas; to turn on, set to window size in pixels (e.g. 100)
                    - "norm3D"=False : compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
                Defaults to True.
            rescale (float, optional): resize factor for each image; if None, set to 1.0
                (only used if diameter is None). Defaults to None.
            diameter (float, optional): diameter for each image;
                if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
            tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
            augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
            resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
            invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
            flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
            cellprob_threshold (float, optional): all pixels with value above threshold kept for masks; decrease to find more and larger masks. Defaults to 0.0.
            do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False.
            anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
            stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
            min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
            flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0.
            niter (int, optional): number of iterations for dynamics computation. If None, it is set proportional to the diameter. Defaults to None.
            interp (bool, optional): interpolate during 2D dynamics (not available in 3D). Defaults to True.

        Returns:
            A tuple (masks, flows, styles, imgs): masks are labelled image(s), where 0=no masks and 1,2,...=mask labels;
                flows is a list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration;
                styles is a style vector summarizing each image of size 256;
                imgs are the restored images.
        """

        if isinstance(normalize, dict):
            normalize_params = {**normalize_default, **normalize}
        elif not isinstance(normalize, bool):
            raise ValueError("normalize parameter must be a bool or a dict")
        else:
            normalize_params = normalize_default
            normalize_params["normalize"] = normalize
        normalize_params["invert"] = invert

        img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels,
                                   channel_axis=channel_axis, z_axis=z_axis,
                                   do_3D=do_3D,
                                   normalize=normalize_params, rescale=rescale,
                                   diameter=diameter,
                                   tile_overlap=tile_overlap, bsize=bsize)

        # turn off special normalization for segmentation
        normalize_params = normalize_default

        # change channels for segmentation
        if channels is not None:
            channels_new = [0, 0] if channels[0] == 0 else [1, 2]
        else:
            channels_new = None
        # change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean)
        diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
        masks, flows, styles = self.cp.eval(
            img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
            z_axis=0 if not isinstance(img_restore, list) and img_restore.ndim > 3 and img_restore.shape[0] > 0 else None,
            normalize=normalize_params, rescale=rescale, diameter=diameter,
            tile_overlap=tile_overlap, augment=augment, resample=resample,
            invert=invert, flow_threshold=flow_threshold,
            cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy,
            stitch_threshold=stitch_threshold, min_size=min_size, niter=niter,
            interp=interp, bsize=bsize)

        return masks, flows, styles, img_restore


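# Hedged usage sketch (restoration followed by segmentation in one call; the
# image variable below is a hypothetical 2D grayscale numpy array):
# >>> model = CellposeDenoiseModel(gpu=True, model_type="cyto3",
# ...                              restore_type="denoise_cyto3")
# >>> masks, flows, styles, img_dn = model.eval(img, channels=[0, 0], diameter=30.)

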
class DenoiseModel():
    """
    DenoiseModel class for denoising images using the Cellpose denoising model.

    Args:
        gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
        pretrained_model (bool or str or Path, optional): Pretrained model to use for denoising.
            Can be a string or path. Defaults to False.
        nchan (int, optional): Number of channels in the input images; all Cellpose 3 models were trained with nchan=1. Defaults to 1.
        model_type (str, optional): Type of pretrained model to use ("denoise_cyto3", "deblur_cyto3", "upsample_cyto3", ...). Defaults to None.
        chan2 (bool, optional): Whether to use a separate model for the second channel. Defaults to False.
        diam_mean (float, optional): Mean diameter of the objects in the images. Defaults to 30.0.
        device (torch.device, optional): Device to use for computation. Defaults to None.

    Attributes:
        nchan (int): Number of channels in the input images.
        diam_mean (float): Mean diameter of the objects in the images.
        net (CPnet): Cellpose network for denoising.
        pretrained_model (bool or str or Path): Pretrained model path to use for denoising.
        net_chan2 (CPnet or None): Cellpose network for the second channel, if applicable.
        net_type (str): Type of the denoising network.

    Methods:
        eval(x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
             normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1):
            Restore array or list of images using the restoration model.

        _eval(net, x, normalize=True, rescale=None, diameter=None, tile=True,
              tile_overlap=0.1):
            Run the restoration model on a single channel.
    """

    def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
                 chan2=False, diam_mean=30., device=None):
        self.nchan = nchan
        if pretrained_model and (not isinstance(pretrained_model, str) and
                                 not isinstance(pretrained_model, Path)):
            raise ValueError("pretrained_model must be a string or path")

        self.diam_mean = diam_mean
        builtin = True
        if model_type is not None or (pretrained_model and
                                      not os.path.exists(pretrained_model)):
            pretrained_model_string = model_type if model_type is not None else "denoise_cyto3"
            if not np.any([pretrained_model_string == s for s in MODEL_NAMES]):
                pretrained_model_string = "denoise_cyto3"
            pretrained_model = model_path(pretrained_model_string)
            if (pretrained_model and not os.path.exists(pretrained_model)):
                denoise_logger.warning("pretrained model has incorrect path")
            denoise_logger.info(f">> {pretrained_model_string} << model set to be used")
            self.diam_mean = 17. if "nuclei" in pretrained_model_string else 30.
        else:
            if pretrained_model:
                builtin = False
                pretrained_model_string = pretrained_model
                denoise_logger.info(f">>>> loading model {pretrained_model_string}")

        # assign network device
        if device is None:
            sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
        self.device = device if device is not None else sdevice
        if device is not None:
            device_gpu = self.device.type == "cuda"
        self.gpu = gpu if device is None else device_gpu

        # create network
        self.nclasses = 1
        nbase = [32, 64, 128, 256]
        self.nbase = [nchan, *nbase]

        self.net = CPnet(self.nbase, self.nclasses, sz=3,
                         max_pool=True, diam_mean=diam_mean).to(self.device)

        self.pretrained_model = pretrained_model
        self.net_chan2 = None
        if self.pretrained_model:
            self.net.load_model(self.pretrained_model, device=self.device)
            denoise_logger.info(
                f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)"
            )
            if chan2 and builtin:
                chan2_path = model_path(
                    os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei")
                print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
                self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
                                       max_pool=True,
                                       diam_mean=17.).to(self.device)
                self.net_chan2.load_model(chan2_path, device=self.device)
        self.net_type = "cellpose_denoise"

    def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
             normalize=True, rescale=None, diameter=None, tile=True, do_3D=False,
             tile_overlap=0.1, bsize=224):
        """
        Restore array or list of images using the image restoration model.

        Args:
            x (list, np.ndarray): can be a list of 2D/3D/4D images, or an array of 2D/3D/4D images.
            batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
                (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
            channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
                First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
                Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
                For instance, to segment grayscale images, input [0,0]. To segment images with cells
                in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
                image with cells in green and nuclei in blue, input [[0,0], [2,3]].
                Defaults to None.
            channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
                If None, the channels dimension is attempted to be automatically determined. Defaults to None.
            z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
                If None, the z dimension is attempted to be automatically determined. Defaults to None.
            normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
                can also pass a dictionary of parameters (all keys are optional, default values shown):
                    - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
                    - "sharpen"=0 : sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
                    - "normalize"=True : run normalization (if False, all following parameters ignored)
                    - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
                    - "tile_norm"=0 : compute normalization in tiles across image to brighten dark areas; to turn on, set to window size in pixels (e.g. 100)
                    - "norm3D"=False : compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
                Defaults to True.
            rescale (float, optional): resize factor for each image; if None, set to 1.0
                (only used if diameter is None). Defaults to None.
            diameter (float, optional): diameter for each image;
                if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
            tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.

        Returns:
            list: A list of 2D/3D arrays of restored images.
        """
        if isinstance(x, list) or x.squeeze().ndim == 5:
            tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO)
            nimg = len(x)
            iterator = trange(nimg, file=tqdm_out,
                              mininterval=30) if nimg > 1 else range(nimg)
            imgs = []
            for i in iterator:
                imgi = self.eval(
                    x[i], batch_size=batch_size,
                    channels=channels[i] if channels is not None and
                    ((len(channels) == len(x) and
                      (isinstance(channels[i], list) or
                       isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2))
                    else channels, channel_axis=channel_axis, z_axis=z_axis,
                    normalize=normalize,
                    do_3D=do_3D,
                    rescale=rescale[i] if isinstance(rescale, list) or
                    isinstance(rescale, np.ndarray) else rescale,
                    diameter=diameter[i] if isinstance(diameter, list) or
                    isinstance(diameter, np.ndarray) else diameter,
                    tile_overlap=tile_overlap, bsize=bsize)
                imgs.append(imgi)
            if isinstance(x, np.ndarray):
                imgs = np.array(imgs)
            return imgs

        else:
            # reshape image
            x = transforms.convert_image(x, channels, channel_axis=channel_axis,
                                         z_axis=z_axis, do_3D=do_3D, nchan=None)
            if x.ndim < 4:
                squeeze = True
                x = x[np.newaxis, ...]
            else:
                squeeze = False

            # may need to interpolate image before running upsampling
            self.ratio = 1.
            if "upsample" in self.pretrained_model:
                Ly, Lx = x.shape[-3:-1]
                if diameter is not None and 3 <= diameter < self.diam_mean:
                    self.ratio = self.diam_mean / diameter
                    denoise_logger.info(
                        f"upsampling image to {self.diam_mean} pixel diameter ({self.ratio:0.2f} times)"
                    )
                    Lyr, Lxr = int(Ly * self.ratio), int(Lx * self.ratio)
                    x = transforms.resize_image(x, Ly=Lyr, Lx=Lxr)
                else:
                    denoise_logger.warning(
                        f"not interpolating image before upsampling because diameter is set >= {self.diam_mean}"
                    )

            self.batch_size = batch_size

            if diameter is not None and diameter > 0:
                rescale = self.diam_mean / diameter
            elif rescale is None:
                rescale = 1.0

            # drop the second channel if it is empty or explicitly disabled
            if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0):
                x = x[..., :1]

            for c in range(x.shape[-1]):
                rescale0 = rescale * 30. / 17. if c == 1 else rescale
                if c == 0 or self.net_chan2 is None:
                    x[..., c] = self._eval(self.net, x[..., c:c + 1],
                                           batch_size=batch_size,
                                           normalize=normalize, rescale=rescale0,
                                           tile_overlap=tile_overlap, bsize=bsize)[..., 0]
                else:
                    x[..., c] = self._eval(self.net_chan2, x[..., c:c + 1],
                                           batch_size=batch_size,
                                           normalize=normalize, rescale=rescale0,
                                           tile_overlap=tile_overlap, bsize=bsize)[..., 0]
|
| 808 |
+
x = x[0] if squeeze else x
|
| 809 |
+
return x
|
| 810 |
+
|
| 811 |
+
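For orientation, a minimal usage sketch of the restoration entry point above. The constructor keywords are the ones used in this file's __main__ block further down; the weights path and the random image are placeholders:

    import numpy as np

    model = DenoiseModel(gpu=True, nchan=1, diam_mean=30.,
                         pretrained_model="/path/to/denoise_weights")  # placeholder path
    img = np.random.rand(512, 512).astype("float32")  # stand-in for a noisy image
    restored = model.eval(img, channels=[0, 0], diameter=30.)
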
+    def _eval(self, net, x, batch_size=8, normalize=True, rescale=None,
+              tile_overlap=0.1, bsize=224):
+        """
+        Run image restoration model on a single channel.
+
+        Args:
+            net (nn.Module): network to run (self.net or self.net_chan2).
+            x (list, np.ndarray): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
+            batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
+                (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
+            normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
+                can also pass a dictionary of parameters (all keys are optional, default values shown):
+                - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
+                - "sharpen"=0 : sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
+                - "normalize"=True : run normalization (if False, all following parameters ignored)
+                - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
+                - "tile_norm"=0 : compute normalization in tiles across image to brighten dark areas; to turn on, set to window size in pixels (e.g. 100)
+                - "norm3D"=False : compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
+                Defaults to True.
+            rescale (float, optional): resize factor for each image; if None, set to 1.0
+                (only used if diameter is None). Defaults to None.
+            tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
+
+        Returns:
+            list: A list of 2D/3D arrays of restored images
+        """
+        if isinstance(normalize, dict):
+            normalize_params = {**normalize_default, **normalize}
+        elif not isinstance(normalize, bool):
+            raise ValueError("normalize parameter must be a bool or a dict")
+        else:
+            normalize_params = normalize_default
+            normalize_params["normalize"] = normalize
+
+        tic = time.time()
+        shape = x.shape
+        nimg = shape[0]
+
+        do_normalization = True if normalize_params["normalize"] else False
+
+        img = np.asarray(x)
+        if do_normalization:
+            img = transforms.normalize_img(img, **normalize_params)
+        if rescale != 1.0:
+            img = transforms.resize_image(img, rsz=rescale)
+        yf, style = run_net(net, img, bsize=bsize,  # use the passed-in net (was self.net, which ignored net_chan2)
+                            tile_overlap=tile_overlap)
+        yf = transforms.resize_image(yf, shape[1], shape[2])
+        imgs = yf
+        del yf, style
+
+        # (previous per-image loop, kept for reference)
+        # imgs = np.zeros((*x.shape[:-1], 1), np.float32)
+        # for i in iterator:
+        #     img = np.asarray(x[i])
+        #     if do_normalization:
+        #         img = transforms.normalize_img(img, **normalize_params)
+        #     if rescale != 1.0:
+        #         img = transforms.resize_image(img, rsz=[rescale, rescale])
+        #     if img.ndim == 2:
+        #         img = img[:, :, np.newaxis]
+        #     yf, style = run_net(net, img, batch_size=batch_size, augment=False,
+        #                         tile=tile, tile_overlap=tile_overlap, bsize=bsize)
+        #     img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2])
+
+        #     if img.ndim == 2:
+        #         img = img[:, :, np.newaxis]
+        #     imgs[i] = img
+        #     del yf, style
+        net_time = time.time() - tic
+        if nimg > 1:
+            denoise_logger.info("imgs denoised in %2.2fs" % (net_time))
+
+        return imgs
+
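Both docstrings above describe the dictionary form of normalize; a sketch with illustrative values (the keys are the documented ones, the numbers are placeholders):

    normalize = {
        "sharpen": 4,             # high-pass filter; ~1/4-1/8 of the cell diameter in pixels
        "percentile": [1., 99.],  # map these intensity percentiles to 0.0 and 1.0
        "tile_norm": 100,         # per-tile normalization window, in pixels
    }
    restored = model.eval(img, channels=[0, 0], normalize=normalize)
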
+def train(net, train_data=None, train_labels=None, train_files=None, test_data=None,
+          test_labels=None, test_files=None, train_probs=None, test_probs=None,
+          lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None,
+          save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0,
+          iso=True, uniform_blur=False, downsample=0., ds_max=7,
+          learning_rate=0.005, n_epochs=500,
+          weight_decay=0.00001, batch_size=8, nimg_per_epoch=None,
+          nimg_test_per_epoch=None, model_name=None):
+    """ train the denoising network net with synthetic degradations (poisson noise, blur, downsampling) """
+
+    # net properties
+    device = net.device
+    nchan = net.nchan
+    diam_mean = net.diam_mean.item()
+
+    args = np.array([poisson, beta, blur, gblur, downsample])
+    if args.ndim == 1:
+        args = args[:, np.newaxis]
+    poisson, beta, blur, gblur, downsample = args
+    nnoise = len(poisson)
+
+    d = datetime.datetime.now()
+    if save_path is not None:
+        if model_name is None:
+            filename = ""
+            lstrs = ["per", "seg", "rec"]
+            for k, (l, s) in enumerate(zip(lam, lstrs)):
+                filename += f"{s}_{l:.2f}_"
+            if not iso:
+                filename += "aniso_"
+            if poisson.sum() > 0:
+                filename += "poisson_"
+            if blur.sum() > 0:
+                filename += "blur_"
+            if downsample.sum() > 0:
+                filename += "downsample_"
+            filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f")
+            filename = os.path.join(save_path, filename)
+        else:
+            filename = os.path.join(save_path, model_name)
+        print(filename)
+    for i in range(len(poisson)):
+        denoise_logger.info(
+            f"poisson: {poisson[i]: 0.2f}, beta: {beta[i]: 0.2f}, blur: {blur[i]: 0.2f}, gblur: {gblur[i]: 0.2f}, downsample: {downsample[i]: 0.2f}"
+        )
+    net1 = one_chan_cellpose(device=device, pretrained_model=seg_model_type)
+
+    learning_rate_const = learning_rate
+    LR = np.linspace(0, learning_rate_const, 10)
+    LR = np.append(LR, learning_rate_const * np.ones(n_epochs - 100))
+    for i in range(10):
+        LR = np.append(LR, LR[-1] / 2 * np.ones(10))
+    learning_rate = LR
+
+    batch_size = 8  # note: hard-coded, overrides the batch_size argument
+    optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate[0],
+                                  weight_decay=weight_decay)
+    if train_data is not None:
+        nimg = len(train_data)
+        diam_train = np.array(
+            [utils.diameters(train_labels[k])[0] for k in trange(len(train_labels))])
+        diam_train[diam_train < 5] = 5.
+        if test_data is not None:
+            diam_test = np.array(
+                [utils.diameters(test_labels[k])[0] for k in trange(len(test_labels))])
+            diam_test[diam_test < 5] = 5.
+            nimg_test = len(test_data)
+    else:
+        nimg = len(train_files)
+        denoise_logger.info(">>> using files instead of loading dataset")
+        train_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in train_files]
+        denoise_logger.info(">>> computing diameters")
+        diam_train = np.array([
+            utils.diameters(io.imread(train_labels_files[k])[0])[0]
+            for k in trange(len(train_labels_files))
+        ])
+        diam_train[diam_train < 5] = 5.
+        if test_files is not None:
+            nimg_test = len(test_files)
+            test_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in test_files]
+            diam_test = np.array([
+                utils.diameters(io.imread(test_labels_files[k])[0])[0]
+                for k in trange(len(test_labels_files))
+            ])
+            diam_test[diam_test < 5] = 5.
+    train_probs = 1. / nimg * np.ones(nimg,
+                                      "float64") if train_probs is None else train_probs
+    if test_files is not None or test_data is not None:
+        test_probs = 1. / nimg_test * np.ones(
+            nimg_test, "float64") if test_probs is None else test_probs
+
+    tic = time.time()
+
+    nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
+    if test_files is not None or test_data is not None:
+        nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch
+
+    nbatch = 0
+    train_losses, test_losses = [], []
+    for iepoch in range(n_epochs):
+        np.random.seed(iepoch)
+        rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
+                                 p=train_probs)
+        torch.manual_seed(iepoch)
+        np.random.seed(iepoch)
+        for param_group in optimizer.param_groups:
+            param_group["lr"] = learning_rate[iepoch]
+        lavg, lavg_per, nsum = 0, 0, 0
+        for ibatch in range(0, nimg_per_epoch, batch_size * nnoise):
+            inds = rperm[ibatch : ibatch + batch_size * nnoise]
+            if train_data is None:
+                imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds]
+                lbls = [io.imread(train_labels_files[i])[1:] for i in inds]
+            else:
+                imgs = [train_data[i][:nchan] for i in inds]
+                lbls = [train_labels[i][1:] for i in inds]
+            #inoise = nbatch % nnoise
+            rnoise = np.random.permutation(nnoise)
+            for i, inoise in enumerate(rnoise):
+                if i * batch_size < len(imgs):
+                    imgi, lbli, scale = random_rotate_and_resize_noise(
+                        imgs[i * batch_size : (i + 1) * batch_size],
+                        lbls[i * batch_size : (i + 1) * batch_size],
+                        diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(),
+                        poisson=poisson[inoise],
+                        beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso,
+                        downsample=downsample[inoise], uniform_blur=uniform_blur,
+                        diam_mean=diam_mean, ds_max=ds_max,
+                        device=device)
+                    if i == 0:
+                        img = imgi
+                        lbl = lbli
+                    else:
+                        img = torch.cat((img, imgi), axis=0)
+                        lbl = torch.cat((lbl, lbli), axis=0)
+
+            if nnoise > 0:
+                iperm = np.random.permutation(img.shape[0])
+                img, lbl = img[iperm], lbl[iperm]
+
+            for i in range(nnoise):
+                optimizer.zero_grad()
+                imgi = img[i * batch_size: (i + 1) * batch_size]
+                lbli = lbl[i * batch_size: (i + 1) * batch_size]
+                if imgi.shape[0] > 0:
+                    loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1,
+                                                img=imgi[:, nchan:], lbl=lbli, lam=lam)
+                    loss.backward()
+                    optimizer.step()
+                    lavg += loss.item() * imgi.shape[0]
+                    lavg_per += loss_per.item() * imgi.shape[0]
+
+            nsum += len(img)
+            nbatch += 1
+
+        if iepoch % 5 == 0 or iepoch < 10:
+            lavg = lavg / nsum
+            lavg_per = lavg_per / nsum
+            if test_data is not None or test_files is not None:
+                lavgt, nsum = 0., 0
+                np.random.seed(42)
+                rperm = np.random.choice(np.arange(0, nimg_test),
+                                         size=(nimg_test_per_epoch,), p=test_probs)
+                inoise = iepoch % nnoise
+                torch.manual_seed(inoise)
+                for ibatch in range(0, nimg_test_per_epoch, batch_size):
+                    inds = rperm[ibatch:ibatch + batch_size]
+                    if test_data is None:
+                        imgs = [
+                            np.maximum(0,
+                                       io.imread(test_files[i])[:nchan]) for i in inds
+                        ]
+                        lbls = [io.imread(test_labels_files[i])[1:] for i in inds]
+                    else:
+                        imgs = [test_data[i][:nchan] for i in inds]
+                        lbls = [test_labels[i][1:] for i in inds]
+                    img, lbl, scale = random_rotate_and_resize_noise(
+                        imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise],
+                        beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise],
+                        iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur,
+                        diam_mean=diam_mean, ds_max=ds_max, device=device)
+                    loss, loss_per = test_loss(net, img[:, :nchan], net1=net1,
+                                               img=img[:, nchan:], lbl=lbl, lam=lam)
+
+                    lavgt += loss.item() * img.shape[0]
+                    nsum += len(img)
+                lavgt = lavgt / nsum
+                denoise_logger.info(
+                    "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f"
+                    % (iepoch, time.time() - tic, lavg, lavg_per, lavgt,
+                       learning_rate[iepoch]))
+                test_losses.append(lavgt)
+            else:
+                denoise_logger.info(
+                    "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" %
+                    (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch]))
+            train_losses.append(lavg)
+
+        if save_path is not None:
+            if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
+                if save_each:  # separate files as model progresses
+                    filename0 = str(filename) + f"_epoch_{iepoch:04d}"  # was {iepoch:%04d}, an invalid format spec
+                else:
+                    filename0 = filename
+                denoise_logger.info(f"saving network parameters to {filename0}")
+                net.save_model(filename0)
+        else:
+            filename = save_path
+
+    return filename, train_losses, test_losses
+
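The learning-rate schedule assembled near the top of train() is a 10-epoch linear warmup, a constant plateau sized as n_epochs - 100, then ten 10-epoch blocks in which the rate halves. A standalone restatement of the same construction:

    import numpy as np

    def denoise_lr_schedule(lr=0.005, n_epochs=500):
        LR = np.linspace(0, lr, 10)                       # warmup: 0 -> lr over 10 epochs
        LR = np.append(LR, lr * np.ones(n_epochs - 100))  # constant plateau
        for _ in range(10):                               # decay: lr/2, lr/4, ..., lr/1024
            LR = np.append(LR, LR[-1] / 2 * np.ones(10))
        return LR

    sched = denoise_lr_schedule()
    assert len(sched) == 510  # 10 + 400 + 100; sched[iepoch] stays in bounds for 500 epochs
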
+if __name__ == "__main__":
+    import argparse
+    parser = argparse.ArgumentParser(description="cellpose parameters")
+
+    input_img_args = parser.add_argument_group("input image arguments")
+    input_img_args.add_argument("--dir", default=[], type=str,
+                                help="folder containing data to run or train on.")
+    input_img_args.add_argument("--img_filter", default=[], type=str,
+                                help="end string for images to run on")
+
+    model_args = parser.add_argument_group("model arguments")
+    model_args.add_argument("--pretrained_model", default=[], type=str,
+                            help="pretrained denoising model")
+
+    training_args = parser.add_argument_group("training arguments")
+    training_args.add_argument("--test_dir", default=[], type=str,
+                               help="folder containing test data (optional)")
+    training_args.add_argument("--file_list", default=[], type=str,
+                               help="npy file containing list of train and test files")
+    training_args.add_argument("--seg_model_type", default="cyto2", type=str,
+                               help="model to use for seg training loss")
+    training_args.add_argument(
+        "--noise_type", default=[], type=str,
+        help="noise type to use (if input, then other noise params are ignored)")
+    training_args.add_argument("--poisson", default=0.8, type=float,
+                               help="fraction of images to add poisson noise to")
+    training_args.add_argument("--beta", default=0.7, type=float,
+                               help="scale of poisson noise")
+    training_args.add_argument("--blur", default=0., type=float,
+                               help="fraction of images to blur")
+    training_args.add_argument("--gblur", default=1.0, type=float,
+                               help="scale of gaussian blurring stddev")
+    training_args.add_argument("--downsample", default=0., type=float,
+                               help="fraction of images to downsample")
+    training_args.add_argument("--ds_max", default=7, type=int,
+                               help="max downsampling factor")
+    training_args.add_argument("--lam_per", default=1.0, type=float,
+                               help="weighting of perceptual loss")
+    training_args.add_argument("--lam_seg", default=1.5, type=float,
+                               help="weighting of segmentation loss")
+    training_args.add_argument("--lam_rec", default=0., type=float,
+                               help="weighting of reconstruction loss")
+    training_args.add_argument(
+        "--diam_mean", default=30., type=float, help=
+        "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0"
+    )
+    training_args.add_argument("--learning_rate", default=0.001, type=float,
+                               help="learning rate. Default: %(default)s")
+    training_args.add_argument("--n_epochs", default=2000, type=int,
+                               help="number of epochs. Default: %(default)s")
+    training_args.add_argument(
+        "--save_each", default=False, action="store_true",
+        help="save each epoch as separate model")
+    training_args.add_argument(
+        "--nimg_per_epoch", default=0, type=int,
+        help="number of images per epoch. Default is length of training images")
+    training_args.add_argument(
+        "--nimg_test_per_epoch", default=0, type=int,
+        help="number of test images per epoch. Default is length of testing images")
+
+    io.logger_setup()
+
+    args = parser.parse_args()
+    lams = [args.lam_per, args.lam_seg, args.lam_rec]
+    print("lam", lams)
+
+    if len(args.noise_type) > 0:
+        noise_type = args.noise_type
+        uniform_blur = False
+        iso = True
+        if noise_type == "poisson":
+            poisson = 0.8
+            blur = 0.
+            downsample = 0.
+            beta = 0.7
+            gblur = 1.0
+        elif noise_type == "blur_expr":
+            poisson = 0.8
+            blur = 0.8
+            downsample = 0.
+            beta = 0.1
+            gblur = 0.5
+        elif noise_type == "blur":
+            poisson = 0.8
+            blur = 0.8
+            downsample = 0.
+            beta = 0.1
+            gblur = 10.0
+            uniform_blur = True
+        elif noise_type == "downsample_expr":
+            poisson = 0.8
+            blur = 0.8
+            downsample = 0.8
+            beta = 0.03
+            gblur = 1.0
+        elif noise_type == "downsample":
+            poisson = 0.8
+            blur = 0.8
+            downsample = 0.8
+            beta = 0.03
+            gblur = 5.0
+            uniform_blur = True
+        elif noise_type == "all":
+            poisson = [0.8, 0.8, 0.8]
+            blur = [0., 0.8, 0.8]
+            downsample = [0., 0., 0.8]
+            beta = [0.7, 0.1, 0.03]
+            gblur = [0., 10.0, 5.0]
+            uniform_blur = True
+        elif noise_type == "aniso":
+            poisson = 0.8
+            blur = 0.8
+            downsample = 0.8
+            beta = 0.1
+            gblur = args.ds_max * 1.5
+            iso = False
+        else:
+            raise ValueError(f"{noise_type} noise_type is not supported")
+    else:
+        poisson, beta = args.poisson, args.beta
+        blur, gblur = args.blur, args.gblur
+        downsample = args.downsample
+
+    pretrained_model = None if len(
+        args.pretrained_model) == 0 else args.pretrained_model
+    model = DenoiseModel(gpu=True, nchan=1, diam_mean=args.diam_mean,
+                         pretrained_model=pretrained_model)
+
+    train_data, labels, train_files, train_probs = None, None, None, None
+    test_data, test_labels, test_files, test_probs = None, None, None, None
+    if len(args.file_list) == 0:
+        output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0)
+        images, labels, image_names, test_images, test_labels, image_names_test = output
+        train_data = []
+        for i in range(len(images)):
+            img = images[i].astype("float32")
+            if img.ndim > 2:
+                img = img[0]
+            train_data.append(
+                np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
+        if len(args.test_dir) > 0:
+            test_data = []
+            for i in range(len(test_images)):
+                img = test_images[i].astype("float32")
+                if img.ndim > 2:
+                    img = img[0]
+                test_data.append(
+                    np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
+        save_path = os.path.join(args.dir, "../models/")
+    else:
+        root = Path(args.dir)  # as a Path so the "/" joins below work
+        denoise_logger.info(
+            ">>> using file_list (assumes images are normalized and have flows!)")
+        dat = np.load(args.file_list, allow_pickle=True).item()
+        train_files = dat["train_files"]
+        test_files = dat["test_files"]
+        train_probs = dat["train_probs"] if "train_probs" in dat else None
+        test_probs = dat["test_probs"] if "test_probs" in dat else None
+        if str(train_files[0])[:len(str(root))] != str(root):
+            for i in range(len(train_files)):
+                new_path = root / Path(*train_files[i].parts[-3:])
+                if i == 0:
+                    print(f"changing path from {train_files[i]} to {new_path}")
+                train_files[i] = new_path
+
+            for i in range(len(test_files)):
+                new_path = root / Path(*test_files[i].parts[-3:])
+                test_files[i] = new_path
+        save_path = os.path.join(args.dir, "models/")
+
+    os.makedirs(save_path, exist_ok=True)
+
+    nimg_per_epoch = None if args.nimg_per_epoch == 0 else args.nimg_per_epoch
+    nimg_test_per_epoch = None if args.nimg_test_per_epoch == 0 else args.nimg_test_per_epoch
+
+    model_path = train(
+        model.net, train_data=train_data, train_labels=labels, train_files=train_files,
+        test_data=test_data, test_labels=test_labels, test_files=test_files,
+        train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta,
+        blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max,
+        iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs,
+        learning_rate=args.learning_rate,
+        lam=lams,
+        seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch,
+        nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path)
+
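A hedged example of invoking this training entry point from the command line. All flags are the ones defined above; the data directories are assumptions, and depending on how the vendored package resolves its relative imports the script may need to be run as a module rather than a file:

    python -m cellpose.denoise --dir /data/train --test_dir /data/test \
        --noise_type poisson --n_epochs 2000 --learning_rate 0.001 \
        --lam_per 1.0 --lam_seg 1.5 --lam_rec 0.
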
+def seg_train_noisy(model, train_data, train_labels, test_data=None, test_labels=None,
+                    poisson=0.8, blur=0.0, downsample=0.0, save_path=None,
+                    save_every=100, save_each=False, learning_rate=0.2, n_epochs=500,
+                    momentum=0.9, weight_decay=0.00001, SGD=True, batch_size=8,
+                    nimg_per_epoch=None, diameter=None, rescale=True, z_masking=False,
+                    model_name=None):
+    """ train function uses loss function model.loss_fn in models.py
+
+    (data should already be normalized)
+    """
+
+    d = datetime.datetime.now()
+
+    model.n_epochs = n_epochs
+    if isinstance(learning_rate, (list, np.ndarray)):
+        if isinstance(learning_rate, np.ndarray) and learning_rate.ndim > 1:
+            raise ValueError("learning_rate.ndim must equal 1")
+        elif len(learning_rate) != n_epochs:
+            raise ValueError(
+                "if learning_rate given as list or np.ndarray it must have length n_epochs"
+            )
+        model.learning_rate = learning_rate
+        model.learning_rate_const = mode(learning_rate)[0][0]
+    else:
+        model.learning_rate_const = learning_rate
+        # set learning rate schedule
+        if SGD:
+            LR = np.linspace(0, model.learning_rate_const, 10)
+            if model.n_epochs > 250:
+                LR = np.append(
+                    LR, model.learning_rate_const * np.ones(model.n_epochs - 100))
+                for i in range(10):
+                    LR = np.append(LR, LR[-1] / 2 * np.ones(10))
+            else:
+                LR = np.append(
+                    LR,
+                    model.learning_rate_const * np.ones(max(0, model.n_epochs - 10)))
+        else:
+            LR = model.learning_rate_const * np.ones(model.n_epochs)
+        model.learning_rate = LR
+
+    model.batch_size = batch_size
+    model._set_optimizer(model.learning_rate[0], momentum, weight_decay, SGD)
+    model._set_criterion()
+
+    nimg = len(train_data)
+
+    # compute average cell diameter
+    if diameter is None:
+        diam_train = np.array(
+            [utils.diameters(train_labels[k][0])[0] for k in range(len(train_labels))])
+        diam_train_mean = diam_train[diam_train > 0].mean()
+        model.diam_labels = diam_train_mean
+        if rescale:
+            diam_train[diam_train < 5] = 5.
+            if test_data is not None:
+                diam_test = np.array([
+                    utils.diameters(test_labels[k][0])[0]
+                    for k in range(len(test_labels))
+                ])
+                diam_test[diam_test < 5] = 5.
+            denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
+    elif rescale:
+        diam_train_mean = diameter
+        model.diam_labels = diameter
+        denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
+        diam_train = diameter * np.ones(len(train_labels), "float32")
+        if test_data is not None:
+            diam_test = diameter * np.ones(len(test_labels), "float32")
+
+    denoise_logger.info(
+        f">>>> mean of training label mask diameters (saved to model) {diam_train_mean:.3f}"
+    )
+    model.net.diam_labels.data = torch.ones(1, device=model.device) * diam_train_mean
+
+    nchan = train_data[0].shape[0]
+    denoise_logger.info(">>>> training network with %d channel input <<<<" % nchan)
+    denoise_logger.info(">>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f" %
+                        (model.learning_rate_const, model.batch_size, weight_decay))
+
+    if test_data is not None:
+        denoise_logger.info(f">>>> ntrain = {nimg}, ntest = {len(test_data)}")
+    else:
+        denoise_logger.info(f">>>> ntrain = {nimg}")
+
+    tic = time.time()
+
+    lavg, nsum = 0, 0
+
+    if save_path is not None:
+        _, file_label = os.path.split(save_path)
+        file_path = os.path.join(save_path, "models/")
+
+        if not os.path.exists(file_path):
+            os.makedirs(file_path)
+    else:
+        denoise_logger.warning("WARNING: no save_path given, model not saving")
+
+    ksave = 0
+
+    # get indices for each epoch for training
+    np.random.seed(0)
+    inds_all = np.zeros((0,), "int32")
+    if nimg_per_epoch is None or nimg > nimg_per_epoch:
+        nimg_per_epoch = nimg
+    denoise_logger.info(f">>>> nimg_per_epoch = {nimg_per_epoch}")
+    while len(inds_all) < n_epochs * nimg_per_epoch:
+        rperm = np.random.permutation(nimg)
+        inds_all = np.hstack((inds_all, rperm))
+
+    for iepoch in range(model.n_epochs):
+        if SGD:
+            model._set_learning_rate(model.learning_rate[iepoch])
+        np.random.seed(iepoch)
+        rperm = inds_all[iepoch * nimg_per_epoch:(iepoch + 1) * nimg_per_epoch]
+        for ibatch in range(0, nimg_per_epoch, batch_size):
+            inds = rperm[ibatch:ibatch + batch_size]
+            imgi, lbl, scale = random_rotate_and_resize_noise(
+                [train_data[i] for i in inds], [train_labels[i][1:] for i in inds],
+                poisson=poisson, blur=blur, downsample=downsample,
+                diams=diam_train[inds], diam_mean=model.diam_mean)
+            imgi = imgi[:, :1]  # keep noisy only
+            if z_masking:
+                nc = imgi.shape[1]
+                nb = imgi.shape[0]
+                ncmin = (np.random.rand(nb) > 0.25) * (np.random.randint(
+                    nc // 2 - 1, size=nb))
+                ncmax = nc - (np.random.rand(nb) > 0.25) * (np.random.randint(
+                    nc // 2 - 1, size=nb))
+                for b in range(nb):
+                    imgi[b, :ncmin[b]] = 0
+                    imgi[b, ncmax[b]:] = 0
+
+            train_loss = model._train_step(imgi, lbl)
+            lavg += train_loss
+            nsum += len(imgi)
+
+        if iepoch % 10 == 0 or iepoch == 5:
+            lavg = lavg / nsum
+            if test_data is not None:
+                lavgt, nsum = 0., 0
+                np.random.seed(42)
+                rperm = np.arange(0, len(test_data), 1, int)
+                for ibatch in range(0, len(test_data), batch_size):
+                    inds = rperm[ibatch:ibatch + batch_size]
+                    imgi, lbl, scale = random_rotate_and_resize_noise(
+                        [test_data[i] for i in inds],
+                        [test_labels[i][1:] for i in inds], poisson=poisson, blur=blur,
+                        downsample=downsample, diams=diam_test[inds],
+                        diam_mean=model.diam_mean)
+                    imgi = imgi[:, :1]  # keep noisy only
+                    test_loss = model._test_eval(imgi, lbl)
+                    lavgt += test_loss
+                    nsum += len(imgi)
+
+                denoise_logger.info(
+                    "Epoch %d, Time %4.1fs, Loss %2.4f, Loss Test %2.4f, LR %2.4f" %
+                    (iepoch, time.time() - tic, lavg, lavgt / nsum,
+                     model.learning_rate[iepoch]))
+            else:
+                denoise_logger.info(
+                    "Epoch %d, Time %4.1fs, Loss %2.4f, LR %2.4f" %
+                    (iepoch, time.time() - tic, lavg, model.learning_rate[iepoch]))
+
+            lavg, nsum = 0, 0
+
+        if save_path is not None:
+            if iepoch == model.n_epochs - 1 or iepoch % save_every == 1:
+                # save model at the end
+                if save_each:  # separate files as model progresses
+                    if model_name is None:
+                        filename = "{}_{}_{}_{}".format(
+                            model.net_type, file_label,
+                            d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch))
+                    else:
+                        filename = "{}_{}".format(model_name, "epoch_" + str(iepoch))
+                else:
+                    if model_name is None:
+                        filename = "{}_{}_{}".format(model.net_type, file_label,
+                                                     d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
+                    else:
+                        filename = model_name
+                filename = os.path.join(file_path, filename)
+                ksave += 1
+                denoise_logger.info(f"saving network parameters to {filename}")
+                model.net.save_model(filename)
+        else:
+            filename = save_path
+
+    return filename
models/seg_post_model/cellpose/dynamics.py
ADDED
@@ -0,0 +1,691 @@
+"""
+Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
+"""
+import os
+from scipy.ndimage import find_objects, center_of_mass, mean
+import torch
+import numpy as np
+import tifffile
+from tqdm import trange
+import fastremap
+
+import logging
+
+dynamics_logger = logging.getLogger(__name__)
+
+from . import utils
+
+import torch.nn.functional as F
+
+def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
+                        device=torch.device("cpu")):
+    """Runs diffusion on GPU to generate flows for training images or quality control.
+
+    Args:
+        neighbors (torch.Tensor): 9 x pixels in masks.
+        meds (torch.Tensor): Mask centers.
+        isneighbor (torch.Tensor): Valid neighbor boolean 9 x pixels.
+        shape (tuple): Shape of the tensor.
+        n_iter (int, optional): Number of iterations. Defaults to 200.
+        device (torch.device, optional): Device to run the computation on. Defaults to torch.device("cpu").
+
+    Returns:
+        np.ndarray: Generated flows.
+    """
+    if torch.prod(torch.tensor(shape)) > 4e7 or device.type == "mps":
+        T = torch.zeros(shape, dtype=torch.float, device=device)
+    else:
+        T = torch.zeros(shape, dtype=torch.double, device=device)
+
+    for i in range(n_iter):
+        T[tuple(meds.T)] += 1
+        Tneigh = T[tuple(neighbors)]
+        Tneigh *= isneighbor
+        T[tuple(neighbors[:, 0])] = Tneigh.mean(axis=0)
+    del meds, isneighbor, Tneigh
+
+    if T.ndim == 2:
+        grads = T[neighbors[0, [2, 1, 4, 3]], neighbors[1, [2, 1, 4, 3]]]
+        del neighbors
+        dy = grads[0] - grads[1]
+        dx = grads[2] - grads[3]
+        del grads
+        mu_torch = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
+    else:
+        grads = T[tuple(neighbors[:, 1:])]
+        del neighbors
+        dz = grads[0] - grads[1]
+        dy = grads[2] - grads[3]
+        dx = grads[4] - grads[5]
+        del grads
+        mu_torch = np.stack(
+            (dz.cpu().squeeze(0), dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
+    return mu_torch
+
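The loop in _extend_centers_gpu is simulated heat diffusion: each iteration deposits a unit of heat at every mask center and replaces every mask pixel with the mean over its within-mask neighborhood; the flow field is then the central-difference gradient of the converged heat map, which points toward the center. A tiny NumPy illustration of the same idea on one square mask (my own simplification; no mask test is needed because the square fills the region):

    import numpy as np

    T = np.zeros((7, 7))                 # padded heat map for a 5x5 mask
    for _ in range(50):
        T[3, 3] += 1                     # deposit heat at the mask center
        # mean over the 4-neighborhood of every interior pixel
        T[1:6, 1:6] = (T[0:5, 1:6] + T[2:7, 1:6] + T[1:6, 0:5] + T[1:6, 2:7]) / 4
    dy, dx = np.gradient(T[1:6, 1:6])    # flows point up the heat gradient, toward (2, 2)
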
+def center_of_mass(mask):
+    # note: shadows the scipy.ndimage import above; returns the in-mask pixel
+    # closest to the mask's centroid
+    yi, xi = np.nonzero(mask)
+    ymean = int(np.round(yi.sum() / len(yi)))
+    xmean = int(np.round(xi.sum() / len(xi)))
+    if not ((yi==ymean) * (xi==xmean)).sum():
+        # center is closest point to (ymean, xmean) within mask
+        imin = ((xi - xmean)**2 + (yi - ymean)**2).argmin()
+        ymean = yi[imin]
+        xmean = xi[imin]
+
+    return ymean, xmean
+
+def get_centers(masks, slices):
+    # per-mask centers in full-image coordinates, plus an extent estimate per mask
+    centers = [center_of_mass(masks[slices[i]]==(i+1)) for i in range(len(slices))]
+    centers = np.array([np.array([centers[i][0] + slices[i][0].start, centers[i][1] + slices[i][1].start])
+                        for i in range(len(slices))])
+    exts = np.array([(slc[0].stop - slc[0].start) + (slc[1].stop - slc[1].start) + 2 for slc in slices])
+    return centers, exts
+
+
+def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
+    """Convert masks to flows using diffusion from center pixel.
+
+    Center of masks where diffusion starts is defined by pixel closest to median within the mask.
+
+    Args:
+        masks (int, 2D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
+        device (torch.device, optional): The device to run the computation on. Defaults to torch.device("cpu").
+        niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
+
+    Returns:
+        np.ndarray: Float array of shape [2 x Ly x Lx] of normalized YX flows, zero outside masks.
+    """
+    if device is None:
+        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
+
+    if masks.max() > 0:
+        Ly0, Lx0 = masks.shape
+        Ly, Lx = Ly0 + 2, Lx0 + 2
+
+        masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
+        masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
+        shape = masks_padded.shape
+
+        ### get mask pixel neighbors
+        y, x = torch.nonzero(masks_padded, as_tuple=True)
+        y = y.int()
+        x = x.int()
+        neighbors = torch.zeros((2, 9, y.shape[0]), dtype=torch.int, device=device)
+        yxi = [[0, -1, 1, 0, 0, -1, -1, 1, 1], [0, 0, 0, -1, 1, -1, 1, -1, 1]]
+        for i in range(9):
+            neighbors[0, i] = y + yxi[0][i]
+            neighbors[1, i] = x + yxi[1][i]
+        isneighbor = torch.ones((9, y.shape[0]), dtype=torch.bool, device=device)
+        m0 = masks_padded[neighbors[0, 0], neighbors[1, 0]]
+        for i in range(1, 9):
+            isneighbor[i] = masks_padded[neighbors[0, i], neighbors[1, i]] == m0
+        del m0, masks_padded
+
+        ### get center-of-mass within cell
+        slices = find_objects(masks)
+        centers, ext = get_centers(masks, slices)
+        meds_p = torch.from_numpy(centers).to(device).long()
+        meds_p += 1  # for padding
+
+        ### run diffusion
+        n_iter = 2 * ext.max() if niter is None else niter
+        mu = _extend_centers_gpu(neighbors, meds_p, isneighbor, shape, n_iter=n_iter,
+                                 device=device)
+        mu = mu.astype("float64")
+
+        # new normalization
+        mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)
+
+        # put into original image
+        mu0 = np.zeros((2, Ly0, Lx0))
+        mu0[:, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
+    else:
+        # no masks, return empty flows
+        mu0 = np.zeros((2, masks.shape[0], masks.shape[1]))
+    return mu0
+
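A minimal sketch of computing flows from a label image with the function above:

    import numpy as np
    import torch

    masks = np.zeros((64, 64), "int32")
    masks[10:30, 10:30] = 1              # one square "cell" labelled 1
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    flows = masks_to_flows_gpu(masks, device=device)  # (2, 64, 64) unit vectors, zero outside masks
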
+def masks_to_flows_gpu_3d(masks, device=None, niter=None):
+    """Convert masks to flows using diffusion from center pixel.
+
+    Args:
+        masks (int, 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
+        device (torch.device, optional): The device to run the computation on. Defaults to None.
+        niter (int, optional): Number of iterations for the diffusion process. Defaults to None.
+
+    Returns:
+        np.ndarray: A 4D array [3 x Lz x Ly x Lx] of normalized flows in Z, Y, and X.
+    """
+    if device is None:
+        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None
+
+    Lz0, Ly0, Lx0 = masks.shape
+    Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2
+
+    masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
+    masks_padded = F.pad(masks_padded, (1, 1, 1, 1, 1, 1))
+
+    # get mask pixel neighbors
+    z, y, x = torch.nonzero(masks_padded).T
+    neighborsZ = torch.stack((z, z + 1, z - 1, z, z, z, z))
+    neighborsY = torch.stack((y, y, y, y + 1, y - 1, y, y), axis=0)
+    neighborsX = torch.stack((x, x, x, x, x, x + 1, x - 1), axis=0)
+
+    neighbors = torch.stack((neighborsZ, neighborsY, neighborsX), axis=0)
+
+    # get mask centers
+    slices = find_objects(masks)
+
+    centers = np.zeros((masks.max(), 3), "int")
+    for i, si in enumerate(slices):
+        if si is not None:
+            sz, sy, sx = si
+            zi, yi, xi = np.nonzero(masks[sz, sy, sx] == (i + 1))
+            zi = zi.astype(np.int32) + 1  # add padding
+            yi = yi.astype(np.int32) + 1  # add padding
+            xi = xi.astype(np.int32) + 1  # add padding
+            zmed = np.mean(zi)
+            ymed = np.mean(yi)
+            xmed = np.mean(xi)
+            imin = np.argmin((zi - zmed)**2 + (xi - xmed)**2 + (yi - ymed)**2)
+            zmed = zi[imin]
+            ymed = yi[imin]
+            xmed = xi[imin]
+            centers[i, 0] = zmed + sz.start
+            centers[i, 1] = ymed + sy.start
+            centers[i, 2] = xmed + sx.start
+
+    # get neighbor validator (not all neighbors are in same mask)
+    neighbor_masks = masks_padded[tuple(neighbors)]
+    isneighbor = neighbor_masks == neighbor_masks[0]
+    ext = np.array(
+        [[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1]
+         for sz, sy, sx in slices])
+    n_iter = 6 * (ext.sum(axis=1)).max() if niter is None else niter
+
+    # run diffusion
+    shape = masks_padded.shape
+    mu = _extend_centers_gpu(neighbors, centers, isneighbor, shape, n_iter=n_iter,
+                             device=device)
+    # normalize
+    mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)
+
+    # put into original image
+    mu0 = np.zeros((3, Lz0, Ly0, Lx0))
+    mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
+    return mu0
+
+def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None,
+                    return_flows=True):
+    """Converts labels (list of masks or flows) to flows for training model.
+
+    Args:
+        labels (list of ND-arrays): The labels to convert. labels[k] can be 2D or 3D. If [3 x Ly x Lx],
+            it is assumed that flows were precomputed. Otherwise, labels[k][0] or labels[k] (if 2D)
+            is used to create flows and cell probabilities.
+        files (list of str, optional): The files to save the flows to. If provided, flows are saved to
+            files to be reused. Defaults to None.
+        device (str, optional): The device to use for computation. Defaults to None.
+        redo_flows (bool, optional): Whether to recompute the flows. Defaults to False.
+        niter (int, optional): The number of iterations for computing flows. Defaults to None.
+        return_flows (bool, optional): Whether to return the computed flows. Defaults to True.
+
+    Returns:
+        list of [4 x Ly x Lx] arrays: The flows for training the model. flows[k][0] is labels[k],
+            flows[k][1] is the cell probability (labels[k] > 0.5), flows[k][2] is the Y flow, and
+            flows[k][3] is the X flow.
+    """
+    nimg = len(labels)
+    if labels[0].ndim < 3:
+        labels = [labels[n][np.newaxis, :, :] for n in range(nimg)]
+
+    flows = []
+    # flows need to be recomputed
+    if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows:
+        dynamics_logger.info("computing flows for labels")
+
+        # compute flows; labels are fixed here to be unique, so they need to be passed back
+        # make sure labels are unique!
+        labels = [fastremap.renumber(label, in_place=True)[0] for label in labels]
+        iterator = trange if nimg > 1 else range
+        for n in iterator(nimg):
+            labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0]
+            vecn = masks_to_flows_gpu(labels[n][0].astype(int), device=device, niter=niter)
+
+            # concatenate labels, cell probability, vector flows (boundary and mask are computed in augmentations)
+            flow = np.concatenate((labels[n], labels[n] > 0.5, vecn),
+                                  axis=0).astype(np.float32)
+            if files is not None:
+                file_name = os.path.splitext(files[n])[0]
+                tifffile.imwrite(file_name + "_flows.tif", flow)
+            if return_flows:
+                flows.append(flow)
+    else:
+        dynamics_logger.info("flows precomputed")
+        if return_flows:
+            flows = [labels[n].astype(np.float32) for n in range(nimg)]
+    return flows
+
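A sketch of preparing training flows from saved masks with labels_to_flows. The file names are placeholders; passing files writes a *_flows.tif next to each image, matching the "_flows.tif" convention that the denoise train() function above reads back:

    import tifffile

    masks = [tifffile.imread(f) for f in ["001_masks.tif", "002_masks.tif"]]
    flows = labels_to_flows(masks, files=["001_img.tif", "002_img.tif"])
    # flows[k] has shape (4, Ly, Lx): labels, cell probability, Y flow, X flow
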
+def flow_error(maski, dP_net, device=None):
+    """Error in flows from predicted masks vs flows predicted by network run on image.
+
+    This function serves to benchmark the quality of masks. It works as follows:
+    1. The predicted masks are used to create a flow diagram.
+    2. The mask-flows are compared to the flows that the network predicted.
+
+    If there is a discrepancy between the flows, it suggests that the mask is incorrect.
+    Masks with flow_errors greater than 0.4 are discarded by default. This setting can be
+    changed in Cellpose.eval or CellposeModel.eval.
+
+    Args:
+        maski (np.ndarray, int): Masks produced from running dynamics on dP_net, where 0=NO masks; 1,2... are mask labels.
+        dP_net (np.ndarray, float): ND flows where dP_net.shape[1:] = maski.shape.
+
+    Returns:
+        A tuple containing (flow_errors, dP_masks): flow_errors (np.ndarray, float): Mean squared error between predicted flows and flows from masks;
+        dP_masks (np.ndarray, float): ND flows produced from the predicted masks.
+    """
+    if dP_net.shape[1:] != maski.shape:
+        print("ERROR: net flow is not same size as predicted masks")
+        return
+
+    # flows predicted from estimated masks
+    dP_masks = masks_to_flows_gpu(maski, device=device)
+    # difference between predicted flows vs mask flows
+    flow_errors = np.zeros(maski.max())
+    for i in range(dP_masks.shape[0]):
+        flow_errors += mean((dP_masks[i] - dP_net[i] / 5.)**2, maski,
+                            index=np.arange(1, maski.max() + 1))
+
+    return flow_errors, dP_masks
+
+
+def steps_interp(dP, inds, niter, device=torch.device("cpu")):
+    """ Run dynamics of pixels to recover masks in 2D/3D, with interpolation between pixel values.
+
+    Euler integration of dynamics dP for niter steps.
+
+    Args:
+        dP (numpy.ndarray): Array of shape (2, Ly, Lx) or (3, Lz, Ly, Lx) representing the flow field.
+        inds (tuple of numpy.ndarray): Initial pixel coordinates, one array per axis.
+        niter (int): Number of iterations to perform.
+        device (torch.device, optional): Device to use for computation. Defaults to torch.device("cpu").
+
+    Returns:
+        torch.Tensor: Array of shape (2, n_points) or (3, n_points) representing the final pixel locations.
+    """
+
+    shape = dP.shape[1:]
+    ndim = len(shape)
+
+    pt = torch.zeros((*[1]*ndim, len(inds[0]), ndim), dtype=torch.float32, device=device)
+    im = torch.zeros((1, ndim, *shape), dtype=torch.float32, device=device)
+    # Y and X dimensions, flipped X-1, Y-1
+    # pt is [1 x 1 x 1 x n_points x 3] in 3D
+    for n in range(ndim):
+        if ndim==3:
+            pt[0, 0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
+        else:
+            pt[0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
+        im[0, ndim - n - 1] = torch.from_numpy(dP[n]).to(device, dtype=torch.float32)
+    shape = np.array(shape)[::-1].astype("float") - 1
+
+    # normalize pt between 0 and 1, normalize the flow
+    for k in range(ndim):
+        im[:, k] *= 2. / shape[k]
+        pt[..., k] /= shape[k]
+
+    # normalize to between -1 and 1
+    pt *= 2
+    pt -= 1
+
+    # dynamics
+    for t in range(niter):
+        dPt = torch.nn.functional.grid_sample(im, pt, align_corners=False)
+        for k in range(ndim):  # clamp the pixel locations
+            pt[..., k] = torch.clamp(pt[..., k] + dPt[:, k], -1., 1.)
+
+    # undo the normalization from before, reverse order of operations
+    pt += 1
+    pt *= 0.5
+    for k in range(ndim):
+        pt[..., k] *= shape[k]
+
+    if ndim==3:
+        pt = pt[..., [2, 1, 0]].squeeze()
+        pt = pt.unsqueeze(0) if pt.ndim==1 else pt
+        return pt.T
+    else:
+        pt = pt[..., [1, 0]].squeeze()
+        pt = pt.unsqueeze(0) if pt.ndim==1 else pt
+        return pt.T
+
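grid_sample in steps_interp expects coordinates normalized to [-1, 1], which is why the point tensor is scaled before the Euler loop and unscaled afterwards. A one-pixel worked example of that round trip (my own illustration):

    Lx = 100                             # image width; valid pixel coords are 0..99
    x = 25.0
    g = (x / (Lx - 1)) * 2 - 1           # to grid_sample space: -0.4949...
    x_back = (g + 1) * 0.5 * (Lx - 1)    # back to pixel space: 25.0
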
def follow_flows(dP, inds, niter=200, device=torch.device("cpu")):
    """ Run dynamics to recover masks in 2D or 3D.

    Pixels are represented as a meshgrid. Only pixels with non-zero cell-probability
    are used (as defined by inds).

    Args:
        dP (np.ndarray): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
        inds (tuple of np.ndarray): Indices of the pixels on which to run dynamics.
        niter (int, optional): Number of iterations of dynamics to run. Default is 200.
        device (torch.device, optional): Device to use for computation. Default is torch.device("cpu").

    Returns:
        p (torch.Tensor): Final locations of each pixel after dynamics, [axis x n_points].
    """
    shape = np.array(dP.shape[1:]).astype(np.int32)
    ndim = len(inds)

    p = steps_interp(dP, inds, niter, device=device)

    return p


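# Illustrative sketch (not in the original source) of how follow_flows is driven
# upstream in compute_masks below: only pixels above the cell-probability
# threshold are integrated, and the network flows are rescaled by 1/5.
def _demo_follow_flows(dP, cellprob, cellprob_threshold=0.0):
    inds = np.nonzero(cellprob > cellprob_threshold)
    p = follow_flows(dP * (cellprob > cellprob_threshold) / 5., inds=inds, niter=200)
    return p  # final (converged) pixel locations for the thresholded pixels

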
def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu")):
    """Remove masks which have inconsistent flows.

    Uses metrics.flow_error to compute flows from predicted masks
    and compare flows to predicted flows from the network. Discards
    masks with flow errors greater than the threshold.

    Args:
        masks (int, 2D or 3D array): Labelled masks, 0=NO masks; 1,2,...=mask labels,
            size [Ly x Lx] or [Lz x Ly x Lx].
        flows (float, 3D or 4D array): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
        threshold (float, optional): Masks with flow error greater than threshold are discarded.
            Default is 0.4.

    Returns:
        masks (int, 2D or 3D array): Masks with inconsistent flow masks removed,
            0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx].
    """
    device0 = device
    if masks.size > 10000 * 10000 and (device is not None and device.type == "cuda"):
        major_version, minor_version = torch.__version__.split(".")[:2]
        torch.cuda.empty_cache()
        if major_version == "1" and int(minor_version) < 10:
            # for PyTorch version lower than 1.10
            def mem_info():
                total_mem = torch.cuda.get_device_properties(device0.index).total_memory
                used_mem = torch.cuda.memory_allocated(device0.index)
                free_mem = total_mem - used_mem
                return total_mem, free_mem
        else:
            # for PyTorch version 1.10 and above
            def mem_info():
                free_mem, total_mem = torch.cuda.mem_get_info(device0.index)
                return total_mem, free_mem
        total_mem, free_mem = mem_info()
        if masks.size * 32 > free_mem:
            dynamics_logger.warning(
                "WARNING: image is very large, not using gpu to compute flows from masks for QC step flow_threshold"
            )
            dynamics_logger.info("turn off QC step with flow_threshold=0 if too slow")
            device0 = torch.device("cpu")

    merrors, _ = flow_error(masks, flows, device0)
    badi = 1 + (merrors > threshold).nonzero()[0]
    masks[np.isin(masks, badi)] = 0
    return masks


def max_pool1d(h, kernel_size=5, axis=1, out=None):
    """ memory efficient max_pool thanks to Mark Kittisopikul

    for stride=1, padding=kernel_size//2, requires odd kernel_size >= 3
    """
    if out is None:
        out = h.clone()
    else:
        out.copy_(h)

    nd = h.shape[axis]
    k0 = kernel_size // 2
    for d in range(-k0, k0 + 1):
        if axis == 1:
            mv = out[:, max(-d, 0):min(nd - d, nd)]
            hv = h[:, max(d, 0):min(nd + d, nd)]
        elif axis == 2:
            mv = out[:, :, max(-d, 0):min(nd - d, nd)]
            hv = h[:, :, max(d, 0):min(nd + d, nd)]
        elif axis == 3:
            mv = out[:, :, :, max(-d, 0):min(nd - d, nd)]
            hv = h[:, :, :, max(d, 0):min(nd + d, nd)]
        torch.maximum(mv, hv, out=mv)
    return out


def max_pool_nd(h, kernel_size=5):
    """ memory efficient max_pool in 2d or 3d """
    ndim = h.ndim - 1
    hmax = max_pool1d(h, kernel_size=kernel_size, axis=1)
    hmax2 = max_pool1d(hmax, kernel_size=kernel_size, axis=2)
    if ndim == 2:
        del hmax
        return hmax2
    else:
        hmax = max_pool1d(hmax2, kernel_size=kernel_size, axis=3, out=hmax)
        del hmax2
        return hmax


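# Quick consistency check (illustrative, not part of the original source): for
# stride 1 and an odd kernel, max_pool_nd should match torch's dense 2D max
# pooling with implicit -inf padding of kernel_size // 2.
def _check_max_pool_nd(kernel_size=5):
    h = torch.rand(1, 32, 32)  # (batch, Ly, Lx), as used in get_masks_torch
    ours = max_pool_nd(h, kernel_size=kernel_size)
    ref = torch.nn.functional.max_pool2d(h.unsqueeze(0), kernel_size=kernel_size,
                                         stride=1, padding=kernel_size // 2).squeeze(0)
    return torch.allclose(ours, ref)  # expected: True

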
def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
    """Create masks using pixel convergence after running dynamics.

    Makes a histogram of final pixel locations pt, initializes masks
    at peaks of the histogram and extends the masks from the peaks so that
    they include all pixels with more than 2 final pixels pt. Discards
    masks that are larger than max_size_fraction of the total image size.

    Parameters:
        pt (torch.Tensor): Final locations of each pixel after dynamics,
            size [axis x n_points].
        inds (tuple of np.ndarray): Indices of the pixels used for dynamics.
        shape0 (tuple): Shape of the original image.
        rpad (int, optional): Histogram edge padding. Default is 20.
        max_size_fraction (float, optional): Masks larger than max_size_fraction of
            total image size are removed. Default is 0.4.

    Returns:
        M0 (int, 2D or 3D array): Masks, 0=NO masks; 1,2,...=mask labels,
            size [Ly x Lx] or [Lz x Ly x Lx].
    """
    ndim = len(shape0)
    device = pt.device

    rpad = 20  # note: shadows the rpad argument (kept from the original code)
    pt += rpad
    pt = torch.clamp(pt, min=0)
    for i in range(len(pt)):
        pt[i] = torch.clamp(pt[i], max=shape0[i] + rpad - 1)

    # # add extra padding to make divisible by 5
    # shape = tuple((np.ceil((shape0 + 2*rpad)/5) * 5).astype(int))
    shape = tuple(np.array(shape0) + 2 * rpad)

    # sparse coo torch
    coo = torch.sparse_coo_tensor(pt, torch.ones(pt.shape[1], device=pt.device, dtype=torch.int),
                                  shape)
    h1 = coo.to_dense()
    del coo

    hmax1 = max_pool_nd(h1.unsqueeze(0), kernel_size=5)
    hmax1 = hmax1.squeeze()
    seeds1 = torch.nonzero((h1 - hmax1 > -1e-6) * (h1 > 10))
    del hmax1
    if len(seeds1) == 0:
        dynamics_logger.warning("no seeds found in get_masks_torch - no masks found.")
        return np.zeros(shape0, dtype="uint16")

    npts = h1[tuple(seeds1.T)]
    isort1 = npts.argsort()
    seeds1 = seeds1[isort1]

    n_seeds = len(seeds1)
    h_slc = torch.zeros((n_seeds, *[11] * ndim), device=seeds1.device)
    for k in range(n_seeds):
        slc = tuple([slice(seeds1[k][j] - 5, seeds1[k][j] + 6) for j in range(ndim)])
        h_slc[k] = h1[slc]
    del h1
    seed_masks = torch.zeros((n_seeds, *[11] * ndim), device=seeds1.device)
    if ndim == 2:
        seed_masks[:, 5, 5] = 1
    else:
        seed_masks[:, 5, 5, 5] = 1

    for iter in range(5):
        # extend
        seed_masks = max_pool_nd(seed_masks, kernel_size=3)
        seed_masks *= h_slc > 2
    del h_slc
    seeds_new = [tuple((torch.nonzero(seed_masks[k]) + seeds1[k] - 5).T)
                 for k in range(n_seeds)]
    del seed_masks

    dtype = torch.int32 if n_seeds < 2**16 else torch.int64
    M1 = torch.zeros(shape, dtype=dtype, device=device)
    for k in range(n_seeds):
        M1[seeds_new[k]] = 1 + k

    M1 = M1[tuple(pt)]
    M1 = M1.cpu().numpy()

    dtype = "uint16" if n_seeds < 2**16 else "uint32"
    M0 = np.zeros(shape0, dtype=dtype)
    M0[inds] = M1

    # remove big masks
    uniq, counts = fastremap.unique(M0, return_counts=True)
    big = np.prod(shape0) * max_size_fraction
    bigc = uniq[counts > big]
    if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
        M0 = fastremap.mask(M0, bigc)
    fastremap.renumber(M0, in_place=True)  # convenient to guarantee non-skipped labels
    M0 = M0.reshape(tuple(shape0))

    # print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9:.3f} gb")
    return M0


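# Illustrative sketch (not part of the original source): two tight clusters of
# converged pixel locations become two single-seed masks. Real inputs come from
# follow_flows; these arrays are synthetic.
def _demo_get_masks_torch():
    ends = np.concatenate([np.full((2, 60), 10), np.full((2, 60), 40)], axis=1)
    starts = tuple(ends.copy())  # pretend each pixel also started at its cluster
    masks = get_masks_torch(torch.from_numpy(ends), starts, (64, 64))
    return masks  # uint16 label image with two labels

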
def resize_and_compute_masks(dP, cellprob, niter=200, cellprob_threshold=0.0,
                             flow_threshold=0.4, do_3D=False, min_size=15,
                             max_size_fraction=0.4, resize=None, device=torch.device("cpu")):
    """Compute masks using dynamics from dP and cellprob (resizing is deprecated and no longer performed).

    Args:
        dP (numpy.ndarray): The dynamics flow field array.
        cellprob (numpy.ndarray): The cell probability array.
        niter (int, optional): The number of iterations for mask computation. Defaults to 200.
        cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0.
        flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4.
        do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
        min_size (int, optional): The minimum size of the masks. Defaults to 15.
        max_size_fraction (float, optional): Masks larger than max_size_fraction of
            total image size are removed. Default is 0.4.
        resize (tuple, optional): Deprecated; only triggers a warning if set. Defaults to None.
        device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu").

    Returns:
        numpy.ndarray: The computed masks.
    """
    mask = compute_masks(dP, cellprob, niter=niter,
                         cellprob_threshold=cellprob_threshold,
                         flow_threshold=flow_threshold, do_3D=do_3D,
                         max_size_fraction=max_size_fraction,
                         device=device)

    if resize is not None:
        dynamics_logger.warning("Resizing is deprecated in v4.0.1+")

    mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)

    return mask


def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
                  flow_threshold=0.4, do_3D=False, min_size=-1,
                  max_size_fraction=0.4, device=torch.device("cpu")):
    """Compute masks using dynamics from dP and cellprob.

    Args:
        dP (numpy.ndarray): The dynamics flow field array.
        cellprob (numpy.ndarray): The cell probability array.
        p (numpy.ndarray, optional): The pixels on which to run dynamics. Defaults to None.
        niter (int, optional): The number of iterations for mask computation. Defaults to 200.
        cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0.
        flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4.
        do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
        min_size (int, optional): The minimum size of the masks; -1 skips the size filter. Defaults to -1.
        max_size_fraction (float, optional): Masks larger than max_size_fraction of
            total image size are removed. Default is 0.4.
        device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu").

    Returns:
        numpy.ndarray: The computed masks.
    """
    if (cellprob > cellprob_threshold).sum():  # mask at this point is a cell cluster binary map, not labels
        inds = np.nonzero(cellprob > cellprob_threshold)
        if len(inds[0]) == 0:
            dynamics_logger.info("No cell pixels found.")
            mask = np.zeros(cellprob.shape, "uint16")
            return mask

        p_final = follow_flows(dP * (cellprob > cellprob_threshold) / 5.,
                               inds=inds, niter=niter,
                               device=device)
        if not torch.is_tensor(p_final):
            p_final = torch.from_numpy(p_final).to(device, dtype=torch.int)
        else:
            p_final = p_final.int()
        # calculate masks
        if device.type == "mps":
            p_final = p_final.to(torch.device("cpu"))
        mask = get_masks_torch(p_final, inds, dP.shape[1:],
                               max_size_fraction=max_size_fraction)
        del p_final
        # flow thresholding factored out of get_masks
        if not do_3D:
            if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0:
                # make sure labels are unique at output of get_masks
                mask = remove_bad_flow_masks(mask, dP, threshold=flow_threshold,
                                             device=device)

        if mask.max() < 2**16 and mask.dtype != "uint16":
            mask = mask.astype("uint16")

    else:  # nothing to compute, just make it compatible
        dynamics_logger.info("No cell pixels found.")
        mask = np.zeros(cellprob.shape, "uint16")
        return mask

    if min_size > 0:
        mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)

    if mask.dtype == np.uint32:
        dynamics_logger.warning(
            "more than 65535 masks in image, masks returned as np.uint32")

    return mask
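
# End-to-end sketch on synthetic network outputs (not part of the original
# source): a single round cell whose flows all point at its center. With real
# data, dP and cellprob come from the network; flow_threshold=0 skips the
# flow-consistency QC step to keep the demo self-contained.
def _demo_compute_masks():
    yy, xx = np.mgrid[:64, :64].astype(np.float32)
    cellprob = np.where((yy - 32) ** 2 + (xx - 32) ** 2 < 15 ** 2, 5., -5.).astype(np.float32)
    dP = np.stack([32 - yy, 32 - xx])  # flows point toward the cell center
    return compute_masks(dP, cellprob, niter=100, flow_threshold=0)  # one labeled cell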
models/seg_post_model/cellpose/export.py
ADDED
@@ -0,0 +1,405 @@
"""Auxiliary module for bioimageio format export

Example usage:

```bash
#!/bin/bash

# Define default paths and parameters
DEFAULT_CHANNELS="1 0"
DEFAULT_PATH_PRETRAINED_MODEL="/home/qinyu/models/cp/cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995"
DEFAULT_PATH_README="/home/qinyu/models/cp/README.md"
DEFAULT_LIST_PATH_COVER_IMAGES="/home/qinyu/images/cp/cellpose_raw_and_segmentation.jpg /home/qinyu/images/cp/cellpose_raw_and_probability.jpg /home/qinyu/images/cp/cellpose_raw.jpg"
DEFAULT_MODEL_ID="philosophical-panda"
DEFAULT_MODEL_ICON="🐼"
DEFAULT_MODEL_VERSION="0.1.0"
DEFAULT_MODEL_NAME="My Cool Cellpose"
DEFAULT_MODEL_DOCUMENTATION="A cool Cellpose model trained for my cool dataset."
DEFAULT_MODEL_AUTHORS='[{"name": "Qin Yu", "affiliation": "EMBL", "github_user": "qin-yu", "orcid": "0000-0002-4652-0795"}]'
DEFAULT_MODEL_CITE='[{"text": "For more details of the model itself, see the manuscript", "doi": "10.1242/dev.202800", "url": null}]'
DEFAULT_MODEL_TAGS="cellpose 3d 2d"
DEFAULT_MODEL_LICENSE="MIT"
DEFAULT_MODEL_REPO="https://github.com/kreshuklab/go-nuclear"

# Run the Python script with default parameters
python export.py \
    --channels $DEFAULT_CHANNELS \
    --path_pretrained_model "$DEFAULT_PATH_PRETRAINED_MODEL" \
    --path_readme "$DEFAULT_PATH_README" \
    --list_path_cover_images $DEFAULT_LIST_PATH_COVER_IMAGES \
    --model_version "$DEFAULT_MODEL_VERSION" \
    --model_name "$DEFAULT_MODEL_NAME" \
    --model_documentation "$DEFAULT_MODEL_DOCUMENTATION" \
    --model_authors "$DEFAULT_MODEL_AUTHORS" \
    --model_cite "$DEFAULT_MODEL_CITE" \
    --model_tags $DEFAULT_MODEL_TAGS \
    --model_license "$DEFAULT_MODEL_LICENSE" \
    --model_repo "$DEFAULT_MODEL_REPO"
```
"""

import os
import sys
import json
import argparse
from pathlib import Path
from urllib.parse import urlparse

import torch
import numpy as np

from cellpose.io import imread
from cellpose.utils import download_url_to_file
from cellpose.transforms import pad_image_ND, normalize_img, convert_image
from cellpose.vit_sam import CPnetBioImageIO

from bioimageio.spec.model.v0_5 import (
    ArchitectureFromFileDescr,
    Author,
    AxisId,
    ChannelAxis,
    CiteEntry,
    Doi,
    FileDescr,
    Identifier,
    InputTensorDescr,
    IntervalOrRatioDataDescr,
    LicenseId,
    ModelDescr,
    ModelId,
    OrcidId,
    OutputTensorDescr,
    ParameterizedSize,
    PytorchStateDictWeightsDescr,
    SizeReference,
    SpaceInputAxis,
    SpaceOutputAxis,
    TensorId,
    TorchscriptWeightsDescr,
    Version,
    WeightsDescr,
)
# Define ARBITRARY_SIZE if it is not available in the module
try:
    from bioimageio.spec.model.v0_5 import ARBITRARY_SIZE
except ImportError:
    ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)

from bioimageio.spec.common import HttpUrl
from bioimageio.spec import save_bioimageio_package
from bioimageio.core import test_model

DEFAULT_CHANNELS = [2, 1]
DEFAULT_NORMALIZE_PARAMS = {
    "axis": -1,
    "lowhigh": None,
    "percentile": None,
    "normalize": True,
    "norm3D": False,
    "sharpen_radius": 0,
    "smooth_radius": 0,
    "tile_norm_blocksize": 0,
    "tile_norm_smooth3D": 1,
    "invert": False,
}
IMAGE_URL = "http://www.cellpose.org/static/data/rgb_3D.tif"


def download_and_normalize_image(path_dir_temp, channels=DEFAULT_CHANNELS):
    """
    Download and normalize image.
    """
    filename = os.path.basename(urlparse(IMAGE_URL).path)
    path_image = path_dir_temp / filename
    if not path_image.exists():
        sys.stderr.write(f'Downloading: "{IMAGE_URL}" to {path_image}\n')
        download_url_to_file(IMAGE_URL, path_image)
    img = imread(path_image).astype(np.float32)
    img = convert_image(img, channels, channel_axis=1, z_axis=0, do_3D=False, nchan=2)
    img = normalize_img(img, **DEFAULT_NORMALIZE_PARAMS)
    img = np.transpose(img, (0, 3, 1, 2))
    img, _, _ = pad_image_ND(img)
    return img


def load_bioimageio_cpnet_model(path_model_weight, nchan=2):
    cpnet_kwargs = {
        "nout": 3,
    }
    cpnet_biio = CPnetBioImageIO(**cpnet_kwargs)
    state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True)
    cpnet_biio.load_state_dict(state_dict_cuda)
    cpnet_biio.eval()  # crucial for the prediction results
    return cpnet_biio, cpnet_kwargs


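# Illustrative sketch (not part of the original script): once loaded, the
# wrapped network maps a (z, nchan, y, x) float tensor to (flow, style,
# *downsampled_features), mirroring how main() below produces the test tensors.
# The weights path here is hypothetical.
#
#     net, kwargs = load_bioimageio_cpnet_model("weights.pt")
#     with torch.no_grad():
#         flow, style, *downsampled = net(torch.zeros(1, 2, 64, 64))

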
def descr_gen_input(path_test_input, nchan=2):
    input_axes = [
        SpaceInputAxis(id=AxisId("z"), size=ARBITRARY_SIZE),
        ChannelAxis(channel_names=[Identifier(f"c{i+1}") for i in range(nchan)]),
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=16, step=16)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=16, step=16)),
    ]
    data_descr = IntervalOrRatioDataDescr(type="float32")
    path_test_input = Path(path_test_input)
    descr_input = InputTensorDescr(
        id=TensorId("raw"),
        axes=input_axes,
        test_tensor=FileDescr(source=path_test_input),
        data=data_descr,
    )
    return descr_input


def descr_gen_output_flow(path_test_output):
    output_axes_output_tensor = [
        SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
        ChannelAxis(channel_names=[Identifier("flow1"), Identifier("flow2"), Identifier("flow3")]),
        SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))),
        SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x"))),
    ]
    path_test_output = Path(path_test_output)
    descr_output = OutputTensorDescr(
        id=TensorId("flow"),
        axes=output_axes_output_tensor,
        test_tensor=FileDescr(source=path_test_output),
    )
    return descr_output


def descr_gen_output_downsampled(path_dir_temp, nbase=None):
    if nbase is None:
        nbase = [32, 64, 128, 256]

    output_axes_downsampled_tensors = [
        [
            SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
            ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(base)]),
            SpaceOutputAxis(
                id=AxisId("y"),
                size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
                scale=2**offset,
            ),
            SpaceOutputAxis(
                id=AxisId("x"),
                size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")),
                scale=2**offset,
            ),
        ]
        for offset, base in enumerate(nbase)
    ]
    path_downsampled_tensors = [
        Path(path_dir_temp / f"test_downsampled_{i}.npy") for i in range(len(output_axes_downsampled_tensors))
    ]
    descr_output_downsampled_tensors = [
        OutputTensorDescr(
            id=TensorId(f"downsampled_{i}"),
            axes=axes,
            test_tensor=FileDescr(source=path),
        )
        for i, (axes, path) in enumerate(zip(output_axes_downsampled_tensors, path_downsampled_tensors))
    ]
    return descr_output_downsampled_tensors


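# Worked example (illustrative, not part of the original script; assumes the
# usual bioimageio convention that a larger axis scale shrinks the advertised
# extent): with scale=2**offset, a raw (y, x) extent of 256 corresponds to
# downsampled extents of 256, 128, 64 and 32 for the four feature levels of
# the default nbase=[32, 64, 128, 256].
def _expected_downsampled_yx(raw=256, n_levels=4):
    return [(raw // 2 ** k, raw // 2 ** k) for k in range(n_levels)]

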
def descr_gen_output_style(path_test_style, nchannel=256):
    output_axes_style_tensor = [
        SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
        ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(nchannel)]),
    ]
    path_style_tensor = Path(path_test_style)
    descr_output_style_tensor = OutputTensorDescr(
        id=TensorId("style"),
        axes=output_axes_style_tensor,
        test_tensor=FileDescr(source=path_style_tensor),
    )
    return descr_output_style_tensor


def descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper=None):
    if path_cpnet_wrapper is None:
        path_cpnet_wrapper = Path(__file__).parent / "resnet_torch.py"
    pytorch_architecture = ArchitectureFromFileDescr(
        callable=Identifier("CPnetBioImageIO"),
        source=Path(path_cpnet_wrapper),
        kwargs=cpnet_kwargs,
    )
    return pytorch_architecture


def descr_gen_documentation(path_doc, markdown_text):
    with open(path_doc, "w") as f:
        f.write(markdown_text)


def package_to_bioimageio(
    path_pretrained_model,
    path_save_trace,
    path_readme,
    list_path_cover_images,
    descr_input,
    descr_output,
    descr_output_downsampled_tensors,
    descr_output_style_tensor,
    pytorch_version,
    pytorch_architecture,
    model_id,
    model_icon,
    model_version,
    model_name,
    model_documentation,
    model_authors,
    model_cite,
    model_tags,
    model_license,
    model_repo,
):
    """Package model description to BioImage.IO format."""
    my_model_descr = ModelDescr(
        id=ModelId(model_id) if model_id is not None else None,
        id_emoji=model_icon,
        version=Version(model_version),
        name=model_name,
        description=model_documentation,
        authors=[
            Author(
                name=author["name"],
                affiliation=author["affiliation"],
                github_user=author["github_user"],
                orcid=OrcidId(author["orcid"]),
            )
            for author in model_authors
        ],
        cite=[CiteEntry(text=cite["text"], doi=Doi(cite["doi"]), url=cite["url"]) for cite in model_cite],
        covers=[Path(img) for img in list_path_cover_images],
        license=LicenseId(model_license),
        tags=model_tags,
        documentation=Path(path_readme),
        git_repo=HttpUrl(model_repo),
        inputs=[descr_input],
        outputs=[descr_output, descr_output_style_tensor] + descr_output_downsampled_tensors,
        weights=WeightsDescr(
            pytorch_state_dict=PytorchStateDictWeightsDescr(
                source=Path(path_pretrained_model),
                architecture=pytorch_architecture,
                pytorch_version=pytorch_version,
            ),
            torchscript=TorchscriptWeightsDescr(
                source=Path(path_save_trace),
                pytorch_version=pytorch_version,
                parent="pytorch_state_dict",  # these weights were converted from the pytorch_state_dict weights.
            ),
        ),
    )

    return my_model_descr


def parse_args():
    # fmt: off
    parser = argparse.ArgumentParser(description="BioImage.IO model packaging for Cellpose")
    parser.add_argument("--channels", nargs=2, default=[2, 1], type=int, help="Cyto-only = [2, 0], Cyto + Nuclei = [2, 1], Nuclei-only = [1, 0]")
    parser.add_argument("--path_pretrained_model", required=True, type=str, help="Path to pretrained model file, e.g., cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995")
    parser.add_argument("--path_readme", required=True, type=str, help="Path to README file")
    parser.add_argument("--list_path_cover_images", nargs='+', required=True, type=str, help="List of paths to cover images")
    parser.add_argument("--model_id", type=str, help="Model ID, provide if already exists", default=None)
    parser.add_argument("--model_icon", type=str, help="Model icon, provide if already exists", default=None)
    parser.add_argument("--model_version", required=True, type=str, help="Model version, new model should be 0.1.0")
    parser.add_argument("--model_name", required=True, type=str, help="Model name, e.g., My Cool Cellpose")
    parser.add_argument("--model_documentation", required=True, type=str, help="Model documentation, e.g., A cool Cellpose model trained for my cool dataset.")
    parser.add_argument("--model_authors", required=True, type=str, help="Model authors in JSON format, e.g., '[{\"name\": \"Qin Yu\", \"affiliation\": \"EMBL\", \"github_user\": \"qin-yu\", \"orcid\": \"0000-0002-4652-0795\"}]'")
    parser.add_argument("--model_cite", required=True, type=str, help="Model citation in JSON format, e.g., '[{\"text\": \"For more details of the model itself, see the manuscript\", \"doi\": \"10.1242/dev.202800\", \"url\": null}]'")
    parser.add_argument("--model_tags", nargs='+', required=True, type=str, help="Model tags, e.g., cellpose 3d 2d")
    parser.add_argument("--model_license", required=True, type=str, help="Model license, e.g., MIT")
    parser.add_argument("--model_repo", required=True, type=str, help="Model repository URL")
    # fmt: on
    return parser.parse_args()


def main():
    args = parse_args()

    # Parse user-provided paths and arguments
    channels = args.channels
    model_cite = json.loads(args.model_cite)
    model_authors = json.loads(args.model_authors)

    path_readme = Path(args.path_readme)
    path_pretrained_model = Path(args.path_pretrained_model)
    list_path_cover_images = [Path(path_image) for path_image in args.list_path_cover_images]

    # Auto-generated paths
    path_cpnet_wrapper = Path(__file__).resolve().parent / "resnet_torch.py"
    path_dir_temp = Path(__file__).resolve().parent.parent / "models" / path_pretrained_model.stem
    path_dir_temp.mkdir(parents=True, exist_ok=True)

    path_save_trace = path_dir_temp / "cp_traced.pt"
    path_test_input = path_dir_temp / "test_input.npy"
    path_test_output = path_dir_temp / "test_output.npy"
    path_test_style = path_dir_temp / "test_style.npy"
    path_bioimageio_package = path_dir_temp / "cellpose_model.zip"

    # Download test input image
    img_np = download_and_normalize_image(path_dir_temp, channels=channels)
    np.save(path_test_input, img_np)
    img = torch.tensor(img_np).float()

    # Load model
    cpnet_biio, cpnet_kwargs = load_bioimageio_cpnet_model(path_pretrained_model)

    # Test model and save output
    tuple_output_tensor = cpnet_biio(img)
    np.save(path_test_output, tuple_output_tensor[0].detach().numpy())
    np.save(path_test_style, tuple_output_tensor[1].detach().numpy())
    for i, t in enumerate(tuple_output_tensor[2:]):
        np.save(path_dir_temp / f"test_downsampled_{i}.npy", t.detach().numpy())

    # Save traced model
    model_traced = torch.jit.trace(cpnet_biio, img)
    model_traced.save(path_save_trace)

    # Generate model description
    descr_input = descr_gen_input(path_test_input)
    descr_output = descr_gen_output_flow(path_test_output)
    descr_output_downsampled_tensors = descr_gen_output_downsampled(path_dir_temp, nbase=cpnet_biio.nbase[1:])
    descr_output_style_tensor = descr_gen_output_style(path_test_style, cpnet_biio.nbase[-1])
    pytorch_version = Version(torch.__version__)
    pytorch_architecture = descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper)

    # Package model
    my_model_descr = package_to_bioimageio(
        path_pretrained_model,
        path_save_trace,
        path_readme,
        list_path_cover_images,
        descr_input,
        descr_output,
        descr_output_downsampled_tensors,
        descr_output_style_tensor,
        pytorch_version,
        pytorch_architecture,
        args.model_id,
        args.model_icon,
        args.model_version,
        args.model_name,
        args.model_documentation,
        model_authors,
        model_cite,
        args.model_tags,
        args.model_license,
        args.model_repo,
    )

    # Test model
    summary = test_model(my_model_descr, weight_format="pytorch_state_dict")
    summary.display()
    summary = test_model(my_model_descr, weight_format="torchscript")
    summary.display()

    # Save BioImage.IO package
    package_path = save_bioimageio_package(my_model_descr, output_path=Path(path_bioimageio_package))
    print("package path:", package_path)


if __name__ == "__main__":
    main()
models/seg_post_model/cellpose/gui/gui.py
ADDED
@@ -0,0 +1,2007 @@
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""

import sys, os, pathlib, warnings, datetime, time, copy

from qtpy import QtGui, QtCore
from superqt import QRangeSlider, QCollapsible
from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, \
    QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, \
    QLineEdit, QMessageBox, QGroupBox, QMenu, QAction
import pyqtgraph as pg

import numpy as np
from scipy.stats import mode
import cv2

from . import guiparts, menus, io
from .. import models, core, dynamics, version, train
from ..utils import download_url_to_file, masks_to_outlines, diameters
from ..io import get_image_files, imsave, imread
from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img
from ..models import normalize_default
from ..plot import disk

try:
    import matplotlib.pyplot as plt
    MATPLOTLIB = True
except Exception:
    MATPLOTLIB = False

Horizontal = QtCore.Qt.Orientation.Horizontal


class Slider(QRangeSlider):

    def __init__(self, parent, name, color):
        super().__init__(Horizontal)
        self.setEnabled(False)
        self.valueChanged.connect(lambda: self.levelChanged(parent))
        self.name = name

        self.setStyleSheet(""" QSlider{
                            background-color: transparent;
                            }
                            """)
        self.show()

    def levelChanged(self, parent):
        parent.level_change(self.name)


class QHLine(QFrame):

    def __init__(self):
        super(QHLine, self).__init__()
        self.setFrameShape(QFrame.HLine)
        self.setLineWidth(8)


def make_bwr():
    # make a bwr colormap
    b = np.append(255 * np.ones(128), np.linspace(0, 255, 128)[::-1])[:, np.newaxis]
    r = np.append(np.linspace(0, 255, 128), 255 * np.ones(128))[:, np.newaxis]
    g = np.append(np.linspace(0, 255, 128),
                  np.linspace(0, 255, 128)[::-1])[:, np.newaxis]
    color = np.concatenate((r, g, b), axis=-1).astype(np.uint8)
    bwr = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
    return bwr


def make_spectral():
    # make spectral colormap
    r = np.array([
        0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80,
        84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 128, 128, 128, 128, 128,
        128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 120, 112, 104, 96, 88,
        80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 7, 11, 15, 19, 23,
        27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103,
        107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167,
        171, 175, 179, 183, 187, 191, 195, 199, 203, 207, 211, 215, 219, 223, 227, 231,
        235, 239, 243, 247, 251, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255
    ])
    g = np.array([
        0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 5, 4, 4, 3, 3,
        2, 2, 1, 1, 0, 0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111,
        119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239,
        247, 255, 247, 239, 231, 223, 215, 207, 199, 191, 183, 175, 167, 159, 151, 143,
        135, 128, 129, 131, 132, 134, 135, 137, 139, 140, 142, 143, 145, 147, 148, 150,
        151, 153, 154, 156, 158, 159, 161, 162, 164, 166, 167, 169, 170, 172, 174, 175,
        177, 178, 180, 181, 183, 185, 186, 188, 189, 191, 193, 194, 196, 197, 199, 201,
        202, 204, 205, 207, 208, 210, 212, 213, 215, 216, 218, 220, 221, 223, 224, 226,
        228, 229, 231, 232, 234, 235, 237, 239, 240, 242, 243, 245, 247, 248, 250, 251,
        253, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199,
        195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135,
        131, 127, 123, 119, 115, 111, 107, 103, 99, 95, 91, 87, 83, 79, 75, 71, 67, 63,
        59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 0, 8, 16, 24, 32, 41,
        49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180,
        189, 197, 205, 213, 222, 230, 238, 246, 254
    ])
    b = np.array([
        0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143,
        151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255,
        255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 251, 247,
        243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183,
        179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 128, 126, 124,
        122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98, 96, 94, 92, 90,
        88, 86, 84, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50,
        48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10,
        8, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
        0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74,
        82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205,
        213, 222, 230, 238, 246, 254
    ])
    color = (np.vstack((r, g, b)).T).astype(np.uint8)
    spectral = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
    return spectral


def make_cmap(cm=0):
    # make a single channel colormap
    r = np.arange(0, 256)
    color = np.zeros((256, 3))
    color[:, cm] = r
    color = color.astype(np.uint8)
    cmap = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color)
    return cmap


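# Illustrative sketch (not part of the original GUI code): the colormaps built
# above are consumed through their lookup tables, e.g. mapping an intensity
# image to RGB with the green single-channel map, mirroring how MainW below
# calls getLookupTable.
def _demo_apply_cmap():
    lut = make_cmap(1).getLookupTable(start=0.0, stop=255.0, alpha=False)  # (n, 3) uint8 rows
    img = np.random.randint(0, lut.shape[0], (64, 64))  # synthetic intensity image
    return lut[img]  # (64, 64, 3) RGB array

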
def run(image=None):
    from ..io import logger_setup
    logger, log_file = logger_setup()
    # Always start by initializing Qt (only once per application)
    warnings.filterwarnings("ignore")
    app = QApplication(sys.argv)
    icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
    guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png")
    if not icon_path.is_file():
        cp_dir = pathlib.Path.home().joinpath(".cellpose")
        cp_dir.mkdir(exist_ok=True)
        print("downloading logo")
        download_url_to_file(
            "https://www.cellpose.org/static/images/cellpose_transparent.png",
            icon_path, progress=True)
    if not guip_path.is_file():
        print("downloading help window image")
        download_url_to_file("https://www.cellpose.org/static/images/cellposeSAM_gui.png",
                             guip_path, progress=True)
    icon_path = str(icon_path.resolve())
    app_icon = QtGui.QIcon()
    app_icon.addFile(icon_path, QtCore.QSize(16, 16))
    app_icon.addFile(icon_path, QtCore.QSize(24, 24))
    app_icon.addFile(icon_path, QtCore.QSize(32, 32))
    app_icon.addFile(icon_path, QtCore.QSize(48, 48))
    app_icon.addFile(icon_path, QtCore.QSize(64, 64))
    app_icon.addFile(icon_path, QtCore.QSize(256, 256))
    app.setWindowIcon(app_icon)
    app.setStyle("Fusion")
    app.setPalette(guiparts.DarkPalette())
    MainW(image=image, logger=logger)
    ret = app.exec_()
    sys.exit(ret)


class MainW(QMainWindow):

    def __init__(self, image=None, logger=None):
        super(MainW, self).__init__()

        self.logger = logger
        pg.setConfigOptions(imageAxisOrder="row-major")
        self.setGeometry(50, 50, 1200, 1000)
        self.setWindowTitle(f"cellpose v{version}")
        self.cp_path = os.path.dirname(os.path.realpath(__file__))
        app_icon = QtGui.QIcon()
        icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
        icon_path = str(icon_path.resolve())
        app_icon.addFile(icon_path, QtCore.QSize(16, 16))
        app_icon.addFile(icon_path, QtCore.QSize(24, 24))
        app_icon.addFile(icon_path, QtCore.QSize(32, 32))
        app_icon.addFile(icon_path, QtCore.QSize(48, 48))
        app_icon.addFile(icon_path, QtCore.QSize(64, 64))
        app_icon.addFile(icon_path, QtCore.QSize(256, 256))
        self.setWindowIcon(app_icon)
        # rgb(150,255,150)
        self.setStyleSheet(guiparts.stylesheet())

        menus.mainmenu(self)
        menus.editmenu(self)
        menus.modelmenu(self)
        menus.helpmenu(self)

        self.stylePressed = """QPushButton {Text-align: center;
                                            background-color: rgb(150,50,150);
                                            border-color: white;
                                            color:white;}
                               QToolTip {
                                            background-color: black;
                                            color: white;
                                            border: black solid 1px
                               }"""
        self.styleUnpressed = """QPushButton {Text-align: center;
                                              background-color: rgb(50,50,50);
                                              border-color: white;
                                              color:white;}
                                 QToolTip {
                                              background-color: black;
                                              color: white;
                                              border: black solid 1px
                                 }"""
        self.loaded = False

        # ---- MAIN WIDGET LAYOUT ---- #
        self.cwidget = QWidget(self)
        self.lmain = QGridLayout()
        self.cwidget.setLayout(self.lmain)
        self.setCentralWidget(self.cwidget)
        self.lmain.setVerticalSpacing(0)
        self.lmain.setContentsMargins(0, 0, 0, 10)

        self.imask = 0
        self.scrollarea = QScrollArea()
        self.scrollarea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn)
        self.scrollarea.setStyleSheet("""QScrollArea { border: none }""")
        self.scrollarea.setWidgetResizable(True)
        self.swidget = QWidget(self)
        self.scrollarea.setWidget(self.swidget)
        self.l0 = QGridLayout()
        self.swidget.setLayout(self.l0)
        b = self.make_buttons()
        self.lmain.addWidget(self.scrollarea, 0, 0, 39, 9)

        # ---- drawing area ---- #
        self.win = pg.GraphicsLayoutWidget()

        self.lmain.addWidget(self.win, 0, 9, 40, 30)

        self.win.scene().sigMouseClicked.connect(self.plot_clicked)
        self.win.scene().sigMouseMoved.connect(self.mouse_moved)
        self.make_viewbox()
        self.lmain.setColumnStretch(10, 1)
        bwrmap = make_bwr()
        self.bwr = bwrmap.getLookupTable(start=0.0, stop=255.0, alpha=False)
        self.cmap = []
        # spectral colormap
        self.cmap.append(make_spectral().getLookupTable(start=0.0, stop=255.0,
                                                        alpha=False))
        # single channel colormaps
        for i in range(3):
            self.cmap.append(
                make_cmap(i).getLookupTable(start=0.0, stop=255.0, alpha=False))

        if MATPLOTLIB:
            self.colormap = (plt.get_cmap("gist_ncar")(np.linspace(0.0, .9, 1000000)) *
                             255).astype(np.uint8)
            np.random.seed(42)  # make colors stable
            self.colormap = self.colormap[np.random.permutation(1000000)]
        else:
            np.random.seed(42)  # make colors stable
            self.colormap = ((np.random.rand(1000000, 3) * 0.8 + 0.1) * 255).astype(
                np.uint8)
        self.NZ = 1
        self.restore = None
        self.ratio = 1.
        self.reset()

        # This needs to go after .reset() is called to get state fully set up:
        self.autobtn.checkStateChanged.connect(self.compute_saturation_if_checked)

        self.load_3D = False

        # if called with image, load it
        if image is not None:
            self.filename = image
            io._load_image(self, self.filename)

        # training settings
        d = datetime.datetime.now()
        self.training_params = {
            "model_index": 0,
            "learning_rate": 1e-5,
            "weight_decay": 0.1,
            "n_epochs": 100,
            "model_name": "cpsam" + d.strftime("_%Y%m%d_%H%M%S"),
        }

        self.stitch_threshold = 0.
        self.flow3D_smooth = 0.
        self.anisotropy = 1.
        self.min_size = 15

        self.setAcceptDrops(True)
        self.win.show()
        self.show()

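    # Note on the mask colormap built in __init__ (descriptive, values taken
    # from the code above): seeding numpy with 42 before permuting the
    # 1,000,000-entry colormap makes mask colors reproducible across sessions,
    # so ROI k is always drawn with self.colormap[k]. A minimal stand-alone
    # check of that determinism:
    #
    #   np.random.seed(42)
    #   perm_a = np.random.permutation(1000000)
    #   np.random.seed(42)
    #   perm_b = np.random.permutation(1000000)
    #   assert (perm_a == perm_b).all()
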
    def help_window(self):
        HW = guiparts.HelpWindow(self)
        HW.show()

    def train_help_window(self):
        THW = guiparts.TrainHelpWindow(self)
        THW.show()

    def gui_window(self):
        EG = guiparts.ExampleGUI(self)
        EG.show()

    def make_buttons(self):
        self.boldfont = QtGui.QFont("Arial", 11, QtGui.QFont.Bold)
        self.boldmedfont = QtGui.QFont("Arial", 9, QtGui.QFont.Bold)
        self.medfont = QtGui.QFont("Arial", 9)
        self.smallfont = QtGui.QFont("Arial", 8)

        b = 0
        self.satBox = QGroupBox("Views")
        self.satBox.setFont(self.boldfont)
        self.satBoxG = QGridLayout()
        self.satBox.setLayout(self.satBoxG)
        self.l0.addWidget(self.satBox, b, 0, 1, 9)

        widget_row = 0
        self.view = 0  # 0=image, 1=flowsXY, 2=flowsZ, 3=cellprob
        self.color = 0  # 0=RGB, 1=gray, 2=R, 3=G, 4=B
        self.RGBDropDown = QComboBox()
        self.RGBDropDown.addItems(
            ["RGB", "red=R", "green=G", "blue=B", "gray", "spectral"])
        self.RGBDropDown.setFont(self.medfont)
        self.RGBDropDown.currentIndexChanged.connect(self.color_choose)
        self.satBoxG.addWidget(self.RGBDropDown, widget_row, 0, 1, 3)

        label = QLabel("<p>[↑ / ↓ or W/S]</p>")
        label.setFont(self.smallfont)
        self.satBoxG.addWidget(label, widget_row, 3, 1, 3)
        label = QLabel("[R / G / B \n toggles color ]")
        label.setFont(self.smallfont)
        self.satBoxG.addWidget(label, widget_row, 6, 1, 3)

        widget_row += 1
        self.ViewDropDown = QComboBox()
        self.ViewDropDown.addItems(["image", "gradXY", "cellprob", "restored"])
        self.ViewDropDown.setFont(self.medfont)
        self.ViewDropDown.model().item(3).setEnabled(False)
        self.ViewDropDown.currentIndexChanged.connect(self.update_plot)
        self.satBoxG.addWidget(self.ViewDropDown, widget_row, 0, 2, 3)

        label = QLabel("[pageup / pagedown]")
        label.setFont(self.smallfont)
        self.satBoxG.addWidget(label, widget_row, 3, 1, 5)

        widget_row += 2
        label = QLabel("")
        label.setToolTip(
            "NOTE: manually changing the saturation bars does not affect normalization in segmentation"
        )
        self.satBoxG.addWidget(label, widget_row, 0, 1, 5)

        self.autobtn = QCheckBox("auto-adjust saturation")
        self.autobtn.setToolTip("sets scale-bars as normalized for segmentation")
        self.autobtn.setFont(self.medfont)
        self.autobtn.setChecked(True)
        self.satBoxG.addWidget(self.autobtn, widget_row, 1, 1, 8)

        widget_row += 1
        self.sliders = []
        colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [100, 100, 100]]
        colornames = ["red", "Chartreuse", "DodgerBlue"]
        names = ["red", "green", "blue"]
        for r in range(3):
            widget_row += 1
            if r == 0:
                label = QLabel('<font color="gray">gray/</font><br>red')
            else:
                label = QLabel(names[r] + ":")
            label.setStyleSheet(f"color: {colornames[r]}")
            label.setFont(self.boldmedfont)
            self.satBoxG.addWidget(label, widget_row, 0, 1, 2)
            self.sliders.append(Slider(self, names[r], colors[r]))
            self.sliders[-1].setMinimum(-.1)
            self.sliders[-1].setMaximum(255.1)
            self.sliders[-1].setValue([0, 255])
            self.sliders[-1].setToolTip(
                "NOTE: manually changing the saturation bars does not affect normalization in segmentation"
            )
            self.satBoxG.addWidget(self.sliders[-1], widget_row, 2, 1, 7)

        b += 1
        self.drawBox = QGroupBox("Drawing")
        self.drawBox.setFont(self.boldfont)
        self.drawBoxG = QGridLayout()
        self.drawBox.setLayout(self.drawBoxG)
        self.l0.addWidget(self.drawBox, b, 0, 1, 9)
        self.autosave = True

        widget_row = 0
        self.brush_size = 3
        self.BrushChoose = QComboBox()
        self.BrushChoose.addItems(["1", "3", "5", "7", "9"])
        self.BrushChoose.currentIndexChanged.connect(self.brush_choose)
        self.BrushChoose.setFixedWidth(40)
        self.BrushChoose.setFont(self.medfont)
        self.drawBoxG.addWidget(self.BrushChoose, widget_row, 3, 1, 2)
        label = QLabel("brush size:")
        label.setFont(self.medfont)
        self.drawBoxG.addWidget(label, widget_row, 0, 1, 3)

        widget_row += 1
        # turn off masks
        self.layer_off = False
        self.masksOn = True
        self.MCheckBox = QCheckBox("MASKS ON [X]")
        self.MCheckBox.setFont(self.medfont)
        self.MCheckBox.setChecked(True)
        self.MCheckBox.toggled.connect(self.toggle_masks)
        self.drawBoxG.addWidget(self.MCheckBox, widget_row, 0, 1, 5)

        widget_row += 1
        # turn off outlines
        self.outlinesOn = False  # turn off by default
        self.OCheckBox = QCheckBox("outlines on [Z]")
        self.OCheckBox.setFont(self.medfont)
        self.drawBoxG.addWidget(self.OCheckBox, widget_row, 0, 1, 5)
        self.OCheckBox.setChecked(False)
        self.OCheckBox.toggled.connect(self.toggle_masks)

        widget_row += 1
        self.SCheckBox = QCheckBox("single stroke")
        self.SCheckBox.setFont(self.medfont)
        self.SCheckBox.setChecked(True)
        self.SCheckBox.toggled.connect(self.autosave_on)
        self.SCheckBox.setEnabled(True)
        self.drawBoxG.addWidget(self.SCheckBox, widget_row, 0, 1, 5)

        # buttons for deleting multiple cells
        self.deleteBox = QGroupBox("delete multiple ROIs")
        self.deleteBox.setStyleSheet("color: rgb(200, 200, 200)")
        self.deleteBox.setFont(self.medfont)
        self.deleteBoxG = QGridLayout()
        self.deleteBox.setLayout(self.deleteBoxG)
        self.drawBoxG.addWidget(self.deleteBox, 0, 5, 4, 4)
        self.MakeDeletionRegionButton = QPushButton("region-select")
        self.MakeDeletionRegionButton.clicked.connect(self.remove_region_cells)
        self.deleteBoxG.addWidget(self.MakeDeletionRegionButton, 0, 0, 1, 4)
        self.MakeDeletionRegionButton.setFont(self.smallfont)
        self.MakeDeletionRegionButton.setFixedWidth(70)
        self.DeleteMultipleROIButton = QPushButton("click-select")
        self.DeleteMultipleROIButton.clicked.connect(self.delete_multiple_cells)
        self.deleteBoxG.addWidget(self.DeleteMultipleROIButton, 1, 0, 1, 4)
        self.DeleteMultipleROIButton.setFont(self.smallfont)
        self.DeleteMultipleROIButton.setFixedWidth(70)
        self.DoneDeleteMultipleROIButton = QPushButton("done")
        self.DoneDeleteMultipleROIButton.clicked.connect(
            self.done_remove_multiple_cells)
        self.deleteBoxG.addWidget(self.DoneDeleteMultipleROIButton, 2, 0, 1, 2)
        self.DoneDeleteMultipleROIButton.setFont(self.smallfont)
        self.DoneDeleteMultipleROIButton.setFixedWidth(35)
        self.CancelDeleteMultipleROIButton = QPushButton("cancel")
        self.CancelDeleteMultipleROIButton.clicked.connect(self.cancel_remove_multiple)
        self.deleteBoxG.addWidget(self.CancelDeleteMultipleROIButton, 2, 2, 1, 2)
        self.CancelDeleteMultipleROIButton.setFont(self.smallfont)
        self.CancelDeleteMultipleROIButton.setFixedWidth(35)

        b += 1
        widget_row = 0
        self.segBox = QGroupBox("Segmentation")
        self.segBoxG = QGridLayout()
        self.segBox.setLayout(self.segBoxG)
        self.l0.addWidget(self.segBox, b, 0, 1, 9)
        self.segBox.setFont(self.boldfont)

        widget_row += 1

        # use GPU
        self.useGPU = QCheckBox("use GPU")
        self.useGPU.setToolTip(
            "if you have specially installed the <i>cuda</i> version of torch, then you can activate this"
        )
        self.useGPU.setFont(self.medfont)
        self.check_gpu()
        self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3)

        # compute segmentation with general models
        self.net_text = ["run CPSAM"]
        nett = ["cellpose super-generalist model"]

        self.StyleButtons = []
        jj = 4
        for j in range(len(self.net_text)):
            self.StyleButtons.append(
                guiparts.ModelButton(self, self.net_text[j], self.net_text[j]))
            w = 5
            self.segBoxG.addWidget(self.StyleButtons[-1], widget_row, jj, 1, w)
            jj += w
            self.StyleButtons[-1].setToolTip(nett[j])

        widget_row += 1
        self.ncells = guiparts.ObservableVariable(0)
        self.roi_count = QLabel()
        self.roi_count.setFont(self.boldfont)
        self.roi_count.setAlignment(QtCore.Qt.AlignLeft)
        self.ncells.valueChanged.connect(
            lambda n: self.roi_count.setText(f'{str(n)} ROIs')
        )

        self.segBoxG.addWidget(self.roi_count, widget_row, 0, 1, 4)

        self.progress = QProgressBar(self)
        self.segBoxG.addWidget(self.progress, widget_row, 4, 1, 5)

        widget_row += 1

        ############################### Segmentation settings ###############################
        self.additional_seg_settings_qcollapsible = QCollapsible("additional settings")
        self.additional_seg_settings_qcollapsible.setFont(self.medfont)
        self.additional_seg_settings_qcollapsible._toggle_btn.setFont(self.medfont)
        self.segmentation_settings = guiparts.SegmentationSettings(self.medfont)
        self.additional_seg_settings_qcollapsible.setContent(self.segmentation_settings)
        self.segBoxG.addWidget(self.additional_seg_settings_qcollapsible,
                               widget_row, 0, 1, 9)

        # connect edits to image processing steps:
        self.segmentation_settings.diameter_box.editingFinished.connect(self.update_scale)
        self.segmentation_settings.flow_threshold_box.returnPressed.connect(self.compute_cprob)
        self.segmentation_settings.cellprob_threshold_box.returnPressed.connect(self.compute_cprob)
        self.segmentation_settings.niter_box.returnPressed.connect(self.compute_cprob)

        # Needed to do this for the drop down to not be open on startup
        self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(True)
        self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(False)

        b += 1
        self.modelBox = QGroupBox("user-trained models")
        self.modelBoxG = QGridLayout()
        self.modelBox.setLayout(self.modelBoxG)
        self.l0.addWidget(self.modelBox, b, 0, 1, 9)
        self.modelBox.setFont(self.boldfont)
        # choose models
        self.ModelChooseC = QComboBox()
        self.ModelChooseC.setFont(self.medfont)
        current_index = 0
        self.ModelChooseC.addItems(["custom models"])
        if len(self.model_strings) > 0:
            self.ModelChooseC.addItems(self.model_strings)
        self.ModelChooseC.setFixedWidth(175)
        self.ModelChooseC.setCurrentIndex(current_index)
        tipstr = 'add or train your own models in the "Models" file menu and choose model here'
        self.ModelChooseC.setToolTip(tipstr)
        self.ModelChooseC.activated.connect(lambda: self.model_choose(custom=True))
        self.modelBoxG.addWidget(self.ModelChooseC, widget_row, 0, 1, 8)

        # compute segmentation w/ custom model
        self.ModelButtonC = QPushButton(u"run")
        self.ModelButtonC.setFont(self.medfont)
        self.ModelButtonC.setFixedWidth(35)
        self.ModelButtonC.clicked.connect(
            lambda: self.compute_segmentation(custom=True))
        self.modelBoxG.addWidget(self.ModelButtonC, widget_row, 8, 1, 1)
        self.ModelButtonC.setEnabled(False)

        b += 1
        self.filterBox = QGroupBox("Image filtering")
        self.filterBox.setFont(self.boldfont)
        self.filterBox_grid_layout = QGridLayout()
        self.filterBox.setLayout(self.filterBox_grid_layout)
        self.l0.addWidget(self.filterBox, b, 0, 1, 9)

        widget_row = 0

        # Filtering
        self.FilterButtons = []
        nett = [
            "clear restore/filter",
            "filter image (settings below)",
        ]
        self.filter_text = ["none",
                            "filter",
                           ]
        self.restore = None
        self.ratio = 1.
        jj = 0
        w = 3
        for j in range(len(self.filter_text)):
            self.FilterButtons.append(
                guiparts.FilterButton(self, self.filter_text[j]))
            self.filterBox_grid_layout.addWidget(self.FilterButtons[-1], widget_row, jj, 1, w)
            self.FilterButtons[-1].setFixedWidth(75)
            self.FilterButtons[-1].setToolTip(nett[j])
            self.FilterButtons[-1].setFont(self.medfont)
            widget_row += 1 if j % 2 == 1 else 0
            jj = 0 if j % 2 == 1 else jj + w

        self.save_norm = QCheckBox("save restored/filtered image")
        self.save_norm.setFont(self.medfont)
        self.save_norm.setToolTip("save restored/filtered image in _seg.npy file")
        self.save_norm.setChecked(True)

        widget_row += 2

        self.filtBox = QCollapsible("custom filter settings")
        self.filtBox._toggle_btn.setFont(self.medfont)
        self.filtBoxG = QGridLayout()
        _content = QWidget()
        _content.setLayout(self.filtBoxG)
        _content.setMaximumHeight(0)
        _content.setMinimumHeight(0)
        self.filtBox.setContent(_content)
        self.filterBox_grid_layout.addWidget(self.filtBox, widget_row, 0, 1, 9)

        self.filt_vals = [0., 0., 0., 0.]
        self.filt_edits = []
        labels = [
            "sharpen\nradius", "smooth\nradius", "tile_norm\nblocksize",
            "tile_norm\nsmooth3D"
        ]
        tooltips = [
            "set size of surround-subtraction filter for sharpening image",
            "set size of gaussian filter for smoothing image",
            "set size of tiles to use to normalize image",
            "set amount of smoothing of normalization values across planes"
        ]

        for p in range(4):
            label = QLabel(f"{labels[p]}:")
            label.setToolTip(tooltips[p])
            label.setFont(self.medfont)
            self.filtBoxG.addWidget(label, widget_row + p // 2, 4 * (p % 2), 1, 2)
            self.filt_edits.append(QLineEdit())
            self.filt_edits[p].setText(str(self.filt_vals[p]))
            self.filt_edits[p].setFixedWidth(40)
            self.filt_edits[p].setFont(self.medfont)
            self.filtBoxG.addWidget(self.filt_edits[p], widget_row + p // 2,
                                    4 * (p % 2) + 2, 1, 2)
            self.filt_edits[p].setToolTip(tooltips[p])

        widget_row += 3
        self.norm3D_cb = QCheckBox("norm3D")
        self.norm3D_cb.setFont(self.medfont)
        self.norm3D_cb.setChecked(True)
        self.norm3D_cb.setToolTip("run same normalization across planes")
        self.filtBoxG.addWidget(self.norm3D_cb, widget_row, 0, 1, 3)

        return b

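    # Layout note (descriptive, not new behavior): every addWidget call in
    # make_buttons uses the five-argument Qt form
    #   layout.addWidget(widget, row, column, rowSpan, columnSpan)
    # where `b` counts group boxes down column 0 of self.l0 and `widget_row`
    # counts rows inside the current QGroupBox grid. A hypothetical extra box
    # spanning the full 9-column panel would therefore be added as:
    #
    #   self.l0.addWidget(example_box, b + 1, 0, 1, 9)   # example_box is hypothetical
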
    def level_change(self, r):
        r = ["red", "green", "blue"].index(r)
        if self.loaded:
            sval = self.sliders[r].value()
            self.saturation[r][self.currentZ] = sval
            if not self.autobtn.isChecked():
                # propagate the current plane's saturation to all planes
                # (the loop reuses the name r, shadowing the channel index above)
                for r in range(3):
                    for i in range(len(self.saturation[r])):
                        self.saturation[r][i] = self.saturation[r][self.currentZ]
            self.update_plot()

    def keyPressEvent(self, event):
        if self.loaded:
            if not (event.modifiers() &
                    (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
                     QtCore.Qt.AltModifier) or self.in_stroke):
                updated = False
                if len(self.current_point_set) > 0:
                    if event.key() == QtCore.Qt.Key_Return:
                        self.add_set()
                else:
                    nviews = self.ViewDropDown.count() - 1
                    nviews += int(
                        self.ViewDropDown.model().item(
                            self.ViewDropDown.count() - 1).isEnabled())
                    if event.key() == QtCore.Qt.Key_X:
                        self.MCheckBox.toggle()
                    if event.key() == QtCore.Qt.Key_Z:
                        self.OCheckBox.toggle()
                    if event.key() == QtCore.Qt.Key_Left or event.key() == QtCore.Qt.Key_A:
                        self.get_prev_image()
                    elif event.key() == QtCore.Qt.Key_Right or event.key() == QtCore.Qt.Key_D:
                        self.get_next_image()
                    elif event.key() == QtCore.Qt.Key_PageDown:
                        self.view = (self.view + 1) % (nviews)
                        self.ViewDropDown.setCurrentIndex(self.view)
                    elif event.key() == QtCore.Qt.Key_PageUp:
                        self.view = (self.view - 1) % (nviews)
                        self.ViewDropDown.setCurrentIndex(self.view)

                # can change background or stroke size if cell not finished
                if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
                    self.color = (self.color - 1) % (6)
                    self.RGBDropDown.setCurrentIndex(self.color)
                elif event.key() == QtCore.Qt.Key_Down or event.key() == QtCore.Qt.Key_S:
                    self.color = (self.color + 1) % (6)
                    self.RGBDropDown.setCurrentIndex(self.color)
                elif event.key() == QtCore.Qt.Key_R:
                    if self.color != 1:
                        self.color = 1
                    else:
                        self.color = 0
                    self.RGBDropDown.setCurrentIndex(self.color)
                elif event.key() == QtCore.Qt.Key_G:
                    if self.color != 2:
                        self.color = 2
                    else:
                        self.color = 0
                    self.RGBDropDown.setCurrentIndex(self.color)
                elif event.key() == QtCore.Qt.Key_B:
                    if self.color != 3:
                        self.color = 3
                    else:
                        self.color = 0
                    self.RGBDropDown.setCurrentIndex(self.color)
                elif (event.key() == QtCore.Qt.Key_Comma or
                      event.key() == QtCore.Qt.Key_Period):
                    count = self.BrushChoose.count()
                    gci = self.BrushChoose.currentIndex()
                    if event.key() == QtCore.Qt.Key_Comma:
                        gci = max(0, gci - 1)
                    else:
                        gci = min(count - 1, gci + 1)
                    self.BrushChoose.setCurrentIndex(gci)
                    self.brush_choose()
                if not updated:
                    self.update_plot()
        if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
            self.p0.keyPressEvent(event)

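    # Keyboard map implemented by the handler above (summary, not new behavior):
    #   X / Z              toggle masks / outlines
    #   Left,A / Right,D   previous / next image in the folder
    #   PageUp / PageDown  cycle the view drop-down (image, gradXY, cellprob, restored)
    #   Up,W / Down,S      cycle the color drop-down
    #   R / G / B          toggle a single-channel view on and off
    #   , / .              decrease / increase brush size
    #   - / =              forwarded to the viewbox for zooming
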
    def autosave_on(self):
        if self.SCheckBox.isChecked():
            self.autosave = True
        else:
            self.autosave = False

    def check_gpu(self, torch=True):
        # also decide whether or not to use torch
        self.useGPU.setChecked(False)
        self.useGPU.setEnabled(False)
        if core.use_gpu(use_torch=True):
            self.useGPU.setEnabled(True)
            self.useGPU.setChecked(True)
        else:
            self.useGPU.setStyleSheet("color: rgb(80,80,80);")


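    # A stand-alone equivalent of the check above (sketch; core.use_gpu is the
    # cellpose helper imported by this module):
    #
    #   if core.use_gpu(use_torch=True):
    #       print("CUDA-enabled torch found, GPU checkbox enabled")
    #   else:
    #       print("no usable GPU, falling back to CPU")
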
    def model_choose(self, custom=False):
        # the non-custom path expects self.ModelChooseB / self.net_names,
        # which are set up elsewhere when built-in models are listed
        index = (self.ModelChooseC.currentIndex()
                 if custom else self.ModelChooseB.currentIndex())
        if index > 0:
            if custom:
                model_name = self.ModelChooseC.currentText()
            else:
                model_name = self.net_names[index - 1]
            print(f"GUI_INFO: selected model {model_name}, loading now")
            self.initialize_model(model_name=model_name, custom=custom)

    def toggle_scale(self):
        if self.scale_on:
            self.p0.removeItem(self.scale)
            self.scale_on = False
        else:
            self.p0.addItem(self.scale)
            self.scale_on = True

|
| 773 |
+
if len(self.model_strings) > 0:
|
| 774 |
+
self.ModelButtonC.setEnabled(True)
|
| 775 |
+
for i in range(len(self.StyleButtons)):
|
| 776 |
+
self.StyleButtons[i].setEnabled(True)
|
| 777 |
+
|
| 778 |
+
for i in range(len(self.FilterButtons)):
|
| 779 |
+
self.FilterButtons[i].setEnabled(True)
|
| 780 |
+
if self.load_3D:
|
| 781 |
+
self.FilterButtons[-2].setEnabled(False)
|
| 782 |
+
|
| 783 |
+
self.newmodel.setEnabled(True)
|
| 784 |
+
self.loadMasks.setEnabled(True)
|
| 785 |
+
|
| 786 |
+
for n in range(self.nchan):
|
| 787 |
+
self.sliders[n].setEnabled(True)
|
| 788 |
+
for n in range(self.nchan, 3):
|
| 789 |
+
self.sliders[n].setEnabled(True)
|
| 790 |
+
|
| 791 |
+
self.toggle_mask_ops()
|
| 792 |
+
|
| 793 |
+
self.update_plot()
|
| 794 |
+
self.setWindowTitle(self.filename)
|
| 795 |
+
|
| 796 |
+
def disable_buttons_removeROIs(self):
|
| 797 |
+
if len(self.model_strings) > 0:
|
| 798 |
+
self.ModelButtonC.setEnabled(False)
|
| 799 |
+
for i in range(len(self.StyleButtons)):
|
| 800 |
+
self.StyleButtons[i].setEnabled(False)
|
| 801 |
+
self.newmodel.setEnabled(False)
|
| 802 |
+
self.loadMasks.setEnabled(False)
|
| 803 |
+
self.saveSet.setEnabled(False)
|
| 804 |
+
self.savePNG.setEnabled(False)
|
| 805 |
+
self.saveFlows.setEnabled(False)
|
| 806 |
+
self.saveOutlines.setEnabled(False)
|
| 807 |
+
self.saveROIs.setEnabled(False)
|
| 808 |
+
|
| 809 |
+
self.MakeDeletionRegionButton.setEnabled(False)
|
| 810 |
+
self.DeleteMultipleROIButton.setEnabled(False)
|
| 811 |
+
self.DoneDeleteMultipleROIButton.setEnabled(True)
|
| 812 |
+
self.CancelDeleteMultipleROIButton.setEnabled(True)
|
| 813 |
+
|
| 814 |
+
def toggle_mask_ops(self):
|
| 815 |
+
self.update_layer()
|
| 816 |
+
self.toggle_saving()
|
| 817 |
+
self.toggle_removals()
|
| 818 |
+
|
| 819 |
+
def toggle_saving(self):
|
| 820 |
+
if self.ncells > 0:
|
| 821 |
+
self.saveSet.setEnabled(True)
|
| 822 |
+
self.savePNG.setEnabled(True)
|
| 823 |
+
self.saveFlows.setEnabled(True)
|
| 824 |
+
self.saveOutlines.setEnabled(True)
|
| 825 |
+
self.saveROIs.setEnabled(True)
|
| 826 |
+
else:
|
| 827 |
+
self.saveSet.setEnabled(False)
|
| 828 |
+
self.savePNG.setEnabled(False)
|
| 829 |
+
self.saveFlows.setEnabled(False)
|
| 830 |
+
self.saveOutlines.setEnabled(False)
|
| 831 |
+
self.saveROIs.setEnabled(False)
|
| 832 |
+
|
| 833 |
+
def toggle_removals(self):
|
| 834 |
+
if self.ncells > 0:
|
| 835 |
+
self.ClearButton.setEnabled(True)
|
| 836 |
+
self.remcell.setEnabled(True)
|
| 837 |
+
self.undo.setEnabled(True)
|
| 838 |
+
self.MakeDeletionRegionButton.setEnabled(True)
|
| 839 |
+
self.DeleteMultipleROIButton.setEnabled(True)
|
| 840 |
+
self.DoneDeleteMultipleROIButton.setEnabled(False)
|
| 841 |
+
self.CancelDeleteMultipleROIButton.setEnabled(False)
|
| 842 |
+
else:
|
| 843 |
+
self.ClearButton.setEnabled(False)
|
| 844 |
+
self.remcell.setEnabled(False)
|
| 845 |
+
self.undo.setEnabled(False)
|
| 846 |
+
self.MakeDeletionRegionButton.setEnabled(False)
|
| 847 |
+
self.DeleteMultipleROIButton.setEnabled(False)
|
| 848 |
+
self.DoneDeleteMultipleROIButton.setEnabled(False)
|
| 849 |
+
self.CancelDeleteMultipleROIButton.setEnabled(False)
|
| 850 |
+
|
| 851 |
+
    def remove_action(self):
        if self.selected > 0:
            self.remove_cell(self.selected)

    def undo_action(self):
        if (len(self.strokes) > 0 and self.strokes[-1][0][0] == self.currentZ):
            self.remove_stroke()
        else:
            # remove previous cell
            if self.ncells > 0:
                self.remove_cell(self.ncells.get())

    def undo_remove_action(self):
        self.undo_remove_cell()

    def get_files(self):
        folder = os.path.dirname(self.filename)
        mask_filter = "_masks"
        images = get_image_files(folder, mask_filter)
        fnames = [os.path.split(images[k])[-1] for k in range(len(images))]
        f0 = os.path.split(self.filename)[-1]
        idx = np.nonzero(np.array(fnames) == f0)[0][0]
        return images, idx

    def get_prev_image(self):
        images, idx = self.get_files()
        idx = (idx - 1) % len(images)
        io._load_image(self, filename=images[idx])

    def get_next_image(self, load_seg=True):
        images, idx = self.get_files()
        idx = (idx + 1) % len(images)
        io._load_image(self, filename=images[idx], load_seg=load_seg)

    def dragEnterEvent(self, event):
        if event.mimeData().hasUrls():
            event.accept()
        else:
            event.ignore()

    def dropEvent(self, event):
        files = [u.toLocalFile() for u in event.mimeData().urls()]
        if os.path.splitext(files[0])[-1] == ".npy":
            io._load_seg(self, filename=files[0], load_3D=self.load_3D)
        else:
            io._load_image(self, filename=files[0], load_seg=True, load_3D=self.load_3D)

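    # Drag-and-drop routing (summary of dropEvent above): only the first dropped
    # URL is used; "*.npy" files are treated as saved _seg.npy segmentations and
    # go through io._load_seg, everything else goes through io._load_image, e.g.
    #
    #   os.path.splitext("/tmp/img_seg.npy")[-1] == ".npy"   # -> _load_seg
    #   os.path.splitext("/tmp/img.tif")[-1] == ".tif"       # -> _load_image
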
    def toggle_masks(self):
        if self.MCheckBox.isChecked():
            self.masksOn = True
        else:
            self.masksOn = False
        if self.OCheckBox.isChecked():
            self.outlinesOn = True
        else:
            self.outlinesOn = False
        if not self.masksOn and not self.outlinesOn:
            self.p0.removeItem(self.layer)
            self.layer_off = True
        else:
            if self.layer_off:
                self.p0.addItem(self.layer)
            self.draw_layer()
            self.update_layer()
        if self.loaded:
            self.update_plot()
            self.update_layer()

    def make_viewbox(self):
        self.p0 = guiparts.ViewBoxNoRightDrag(parent=self, lockAspect=True,
                                              name="plot1",
                                              border=[100, 100, 100], invertY=True)
        self.p0.setCursor(QtCore.Qt.CrossCursor)
        self.brush_size = 3
        self.win.addItem(self.p0, 0, 0, rowspan=1, colspan=1)
        self.p0.setMenuEnabled(False)
        self.p0.setMouseEnabled(x=True, y=True)
        self.img = pg.ImageItem(viewbox=self.p0, parent=self)
        self.img.autoDownsample = False
        self.layer = guiparts.ImageDraw(viewbox=self.p0, parent=self)
        self.layer.setLevels([0, 255])
        self.scale = pg.ImageItem(viewbox=self.p0, parent=self)
        self.scale.setLevels([0, 255])
        self.p0.scene().contextMenuItem = self.p0
        self.Ly, self.Lx = 512, 512
        self.p0.addItem(self.img)
        self.p0.addItem(self.layer)
        self.p0.addItem(self.scale)

    def reset(self):
        # ---- start sets of points ---- #
        self.selected = 0
        self.nchan = 3
        self.loaded = False
        self.channel = [0, 1]
        self.current_point_set = []
        self.in_stroke = False
        self.strokes = []
        self.stroke_appended = True
        self.resize = False
        self.ncells.reset()
        self.zdraw = []
        self.removed_cell = []
        self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]

        # -- zero out image stack -- #
        self.opacity = 128  # how opaque masks should be
        self.outcolor = [200, 200, 255, 200]
        self.NZ, self.Ly, self.Lx = 1, 256, 256
        self.saturation = self.saturation if hasattr(self, 'saturation') else []

        # only adjust the saturation if auto-adjust is on:
        if self.autobtn.isChecked():
            for r in range(3):
                self.saturation.append([[0, 255] for n in range(self.NZ)])
                self.sliders[r].setValue([0, 255])
                self.sliders[r].setEnabled(False)
                self.sliders[r].show()
        self.currentZ = 0
        self.flows = [[], [], [], [], [[]]]
        # masks matrix
        # image matrix with a scale disk
        self.stack = np.zeros((1, self.Ly, self.Lx, 3))
        self.Lyr, self.Lxr = self.Ly, self.Lx
        self.Ly0, self.Lx0 = self.Ly, self.Lx
        self.radii = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
        self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
        self.cellpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
        self.outpix = np.zeros((1, self.Ly, self.Lx), np.uint16)
        self.ismanual = np.zeros(0, "bool")

        # -- set menus to default -- #
        self.color = 0
        self.RGBDropDown.setCurrentIndex(self.color)
        self.view = 0
        self.ViewDropDown.setCurrentIndex(0)
        self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
        self.delete_restore()

        self.clear_all()

        self.filename = []
        self.loaded = False
        self.recompute_masks = False

        self.deleting_multiple = False
        self.removing_cells_list = []
        self.removing_region = False
        self.remove_roi_obj = None

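    # State invariant after reset() (descriptive, derived from the code above):
    # cellpix/outpix are uint16 label images of shape (NZ, Ly, Lx) where 0 is
    # background and ROI k occupies value k; layerz is the (Ly, Lx, 4) RGBA
    # overlay actually blitted to screen; self.flows holds per-view model
    # outputs, filled in once segmentation has run.
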
    def delete_restore(self):
        """ delete restored imgs but don't reset settings """
        if hasattr(self, "stack_filtered"):
            del self.stack_filtered
        if hasattr(self, "cellpix_orig"):
            self.cellpix = self.cellpix_orig.copy()
            self.outpix = self.outpix_orig.copy()
            del self.outpix_orig, self.outpix_resize
            del self.cellpix_orig, self.cellpix_resize

    def clear_restore(self):
        """ delete restored imgs and reset settings """
        print("GUI_INFO: clearing restored image")
        self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False)
        if self.ViewDropDown.currentIndex() == self.ViewDropDown.count() - 1:
            self.ViewDropDown.setCurrentIndex(0)
        self.delete_restore()
        self.restore = None
        self.ratio = 1.
        self.set_normalize_params(self.get_normalize_params())

    def brush_choose(self):
        self.brush_size = self.BrushChoose.currentIndex() * 2 + 1
        if self.loaded:
            self.layer.setDrawKernel(kernel_size=self.brush_size)
            self.update_layer()

    def clear_all(self):
        self.prev_selected = 0
        self.selected = 0
        if self.restore and "upsample" in self.restore:
            self.layerz = 0 * np.ones((self.Lyr, self.Lxr, 4), np.uint8)
            self.cellpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16)
            self.outpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16)
            self.cellpix_resize = self.cellpix.copy()
            self.outpix_resize = self.outpix.copy()
            self.cellpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16)
            self.outpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16)
        else:
            self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8)
            self.cellpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16)
            self.outpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16)

        self.cellcolors = np.array([255, 255, 255])[np.newaxis, :]
        self.ncells.reset()
        self.toggle_removals()
        self.update_scale()
        self.update_layer()

    def select_cell(self, idx):
        self.prev_selected = self.selected
        self.selected = idx
        if self.selected > 0:
            z = self.currentZ
            self.layerz[self.cellpix[z] == idx] = np.array(
                [255, 255, 255, self.opacity])
            self.update_layer()

    def select_cell_multi(self, idx):
        if idx > 0:
            z = self.currentZ
            self.layerz[self.cellpix[z] == idx] = np.array(
                [255, 255, 255, self.opacity])
            self.update_layer()

    def unselect_cell(self):
        if self.selected > 0:
            idx = self.selected
            if idx < (self.ncells.get() + 1):
                z = self.currentZ
                self.layerz[self.cellpix[z] == idx] = np.append(
                    self.cellcolors[idx], self.opacity)
                if self.outlinesOn:
                    self.layerz[self.outpix[z] == idx] = np.array(
                        self.outcolor).astype(np.uint8)
                    # [0,0,0,self.opacity])
                self.update_layer()
        self.selected = 0

    def unselect_cell_multi(self, idx):
        z = self.currentZ
        self.layerz[self.cellpix[z] == idx] = np.append(self.cellcolors[idx],
                                                        self.opacity)
        if self.outlinesOn:
            self.layerz[self.outpix[z] == idx] = np.array(
                self.outcolor).astype(np.uint8)
            # [0,0,0,self.opacity])
        self.update_layer()

    def remove_cell(self, idx):
        if isinstance(idx, (int, np.integer)):
            idx = [idx]
        # because the function remove_single_cell updates the state of the cellpix
        # and outpix arrays by reindexing cells to avoid gaps in the indices, we
        # need to remove the cells in reverse order so that the indices are correct
        idx.sort(reverse=True)
        for i in idx:
            self.remove_single_cell(i)
        self.ncells -= len(idx)  # _save_sets uses ncells
        self.update_layer()

        if self.ncells == 0:
            self.ClearButton.setEnabled(False)
        if self.NZ == 1:
            io._save_sets_with_check(self)

    def remove_single_cell(self, idx):
        # remove from manual array
        self.selected = 0
        if self.NZ > 1:
            zextent = ((self.cellpix == idx).sum(axis=(1, 2)) > 0).nonzero()[0]
        else:
            zextent = [0]
        for z in zextent:
            cp = self.cellpix[z] == idx
            op = self.outpix[z] == idx
            # remove from self.cellpix and self.outpix
            self.cellpix[z, cp] = 0
            self.outpix[z, op] = 0
            if z == self.currentZ:
                # remove from mask layer
                self.layerz[cp] = np.array([0, 0, 0, 0])

        # reduce other pixels by -1
        self.cellpix[self.cellpix > idx] -= 1
        self.outpix[self.outpix > idx] -= 1

        if self.NZ == 1:
            self.removed_cell = [
                self.ismanual[idx - 1], self.cellcolors[idx],
                np.nonzero(cp),
                np.nonzero(op)
            ]
            self.redo.setEnabled(True)
            ar, ac = self.removed_cell[2]
            d = datetime.datetime.now()
            self.track_changes.append(
                [d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]])
        # remove cell from lists
        self.ismanual = np.delete(self.ismanual, idx - 1)
        self.cellcolors = np.delete(self.cellcolors, [idx], axis=0)
        del self.zdraw[idx - 1]
        print("GUI_INFO: removed cell %d" % (idx - 1))

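    # The relabeling step above keeps mask indices contiguous. A minimal numpy
    # illustration of the same invariant (values are made up):
    #
    #   cellpix = np.array([[1, 2, 3]], np.uint16); idx = 2
    #   cellpix[cellpix == idx] = 0     # -> [[1, 0, 3]]
    #   cellpix[cellpix > idx] -= 1     # -> [[1, 0, 2]]
    #
    # which is why remove_cell() deletes a batch of indices in descending order.
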
    def remove_region_cells(self):
        if self.removing_cells_list:
            for idx in self.removing_cells_list:
                self.unselect_cell_multi(idx)
            self.removing_cells_list.clear()
        self.disable_buttons_removeROIs()
        self.removing_region = True

        self.clear_multi_selected_cells()

        # make roi region here in center of view, making ROI half the size of the view
        roi_width = self.p0.viewRect().width() / 2
        x_loc = self.p0.viewRect().x() + (roi_width / 2)
        roi_height = self.p0.viewRect().height() / 2
        y_loc = self.p0.viewRect().y() + (roi_height / 2)

        pos = [x_loc, y_loc]
        roi = pg.RectROI(pos, [roi_width, roi_height], pen=pg.mkPen("y", width=2),
                         removable=True)
        roi.sigRemoveRequested.connect(self.remove_roi)
        roi.sigRegionChangeFinished.connect(self.roi_changed)
        self.p0.addItem(roi)
        self.remove_roi_obj = roi
        self.roi_changed(roi)

    def delete_multiple_cells(self):
        self.unselect_cell()
        self.disable_buttons_removeROIs()
        self.DoneDeleteMultipleROIButton.setEnabled(True)
        self.MakeDeletionRegionButton.setEnabled(True)
        self.CancelDeleteMultipleROIButton.setEnabled(True)
        self.deleting_multiple = True

    def done_remove_multiple_cells(self):
        self.deleting_multiple = False
        self.removing_region = False
        self.DoneDeleteMultipleROIButton.setEnabled(False)
        self.MakeDeletionRegionButton.setEnabled(False)
        self.CancelDeleteMultipleROIButton.setEnabled(False)

        if self.removing_cells_list:
            self.removing_cells_list = list(set(self.removing_cells_list))
            display_remove_list = [i - 1 for i in self.removing_cells_list]
            print(f"GUI_INFO: removing cells: {display_remove_list}")
            self.remove_cell(self.removing_cells_list)
            self.removing_cells_list.clear()
            self.unselect_cell()
        self.enable_buttons()

        if self.remove_roi_obj is not None:
            self.remove_roi(self.remove_roi_obj)

    def merge_cells(self, idx):
        self.prev_selected = self.selected
        self.selected = idx
        if self.selected != self.prev_selected:
            for z in range(self.NZ):
                ar0, ac0 = np.nonzero(self.cellpix[z] == self.prev_selected)
                ar1, ac1 = np.nonzero(self.cellpix[z] == self.selected)
                touching = np.logical_and((ar0[:, np.newaxis] - ar1) < 3,
                                          (ac0[:, np.newaxis] - ac1) < 3).sum()
                ar = np.hstack((ar0, ar1))
                ac = np.hstack((ac0, ac1))
                vr0, vc0 = np.nonzero(self.outpix[z] == self.prev_selected)
                vr1, vc1 = np.nonzero(self.outpix[z] == self.selected)
                self.outpix[z, vr0, vc0] = 0
                self.outpix[z, vr1, vc1] = 0
                if touching > 0:
                    mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8)
                    mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
                    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                                cv2.CHAIN_APPROX_NONE)
                    pvc, pvr = contours[-2][0].squeeze().T
                    vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
                else:
                    vr = np.hstack((vr0, vr1))
                    vc = np.hstack((vc0, vc1))
                color = self.cellcolors[self.prev_selected]
                self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected)
            self.remove_cell(self.selected)
            print("GUI_INFO: merged two cells")
            self.update_layer()
            io._save_sets_with_check(self)
            self.undo.setEnabled(False)
            self.redo.setEnabled(False)

    def undo_remove_cell(self):
        if len(self.removed_cell) > 0:
            z = 0
            ar, ac = self.removed_cell[2]
            vr, vc = self.removed_cell[3]
            color = self.removed_cell[1]
            self.draw_mask(z, ar, ac, vr, vc, color)
            self.toggle_mask_ops()
            self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], axis=0)
            self.ncells += 1
            self.ismanual = np.append(self.ismanual, self.removed_cell[0])
            self.zdraw.append([])
            print(">>> added back removed cell")
            self.update_layer()
            io._save_sets_with_check(self)
            self.removed_cell = []
            self.redo.setEnabled(False)

    def remove_stroke(self, delete_points=True, stroke_ind=-1):
        stroke = np.array(self.strokes[stroke_ind])
        cZ = self.currentZ
        inZ = stroke[0, 0] == cZ
        if inZ:
            outpix = self.outpix[cZ, stroke[:, 1], stroke[:, 2]] > 0
            self.layerz[stroke[~outpix, 1], stroke[~outpix, 2]] = np.array([0, 0, 0, 0])
            cellpix = self.cellpix[cZ, stroke[:, 1], stroke[:, 2]]
            ccol = self.cellcolors.copy()
            if self.selected > 0:
                ccol[self.selected] = np.array([255, 255, 255])
            col2mask = ccol[cellpix]
            if self.masksOn:
                col2mask = np.concatenate(
                    (col2mask, self.opacity * (cellpix[:, np.newaxis] > 0)), axis=-1)
            else:
                col2mask = np.concatenate(
                    (col2mask, 0 * (cellpix[:, np.newaxis] > 0)), axis=-1)
            self.layerz[stroke[:, 1], stroke[:, 2], :] = col2mask
            if self.outlinesOn:
                self.layerz[stroke[outpix, 1], stroke[outpix, 2]] = np.array(self.outcolor)
            if delete_points:
                del self.current_point_set[stroke_ind]
            self.update_layer()

        del self.strokes[stroke_ind]

    def plot_clicked(self, event):
        if event.button() == QtCore.Qt.LeftButton \
                and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier) \
                and not self.removing_region:
            if event.double():
                # self.pr (scale-disk padding) may not exist before a scale is drawn
                try:
                    self.p0.setYRange(0, self.Ly + self.pr)
                except Exception:
                    self.p0.setYRange(0, self.Ly)
                self.p0.setXRange(0, self.Lx)

    def cancel_remove_multiple(self):
        self.clear_multi_selected_cells()
        self.done_remove_multiple_cells()

    def clear_multi_selected_cells(self):
        # unselect all previously selected cells:
        for idx in self.removing_cells_list:
            self.unselect_cell_multi(idx)
        self.removing_cells_list.clear()

    def add_roi(self, roi):
        self.p0.addItem(roi)
        self.remove_roi_obj = roi

    def remove_roi(self, roi):
        self.clear_multi_selected_cells()
        assert roi == self.remove_roi_obj
        self.remove_roi_obj = None
        self.p0.removeItem(roi)
        self.removing_region = False

    def roi_changed(self, roi):
        # find the overlapping cells and make them selected
        pos = roi.pos()
        size = roi.size()
        x0 = int(pos.x())
        y0 = int(pos.y())
        x1 = int(pos.x() + size.x())
        y1 = int(pos.y() + size.y())
        if x0 < 0:
            x0 = 0
        if y0 < 0:
            y0 = 0
        if x1 > self.Lx:
            x1 = self.Lx
        if y1 > self.Ly:
            y1 = self.Ly

        # find cells in that region
        cell_idxs = np.unique(self.cellpix[self.currentZ, y0:y1, x0:x1])
        cell_idxs = np.trim_zeros(cell_idxs)
        # deselect cells not in region by deselecting all and then selecting the ones in the region
        self.clear_multi_selected_cells()

        for idx in cell_idxs:
            self.select_cell_multi(idx)
            self.removing_cells_list.append(idx)

        self.update_layer()

    def mouse_moved(self, pos):
        items = self.win.scene().items(pos)

    def color_choose(self):
        self.color = self.RGBDropDown.currentIndex()
        self.view = 0
        self.ViewDropDown.setCurrentIndex(self.view)
        self.update_plot()

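    # roi_changed() above clamps the rectangle to the image bounds before
    # reading labels; a compact equivalent of its four if-statements (sketch):
    #
    #   x0, y0 = max(x0, 0), max(y0, 0)
    #   x1, y1 = min(x1, self.Lx), min(y1, self.Ly)
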
    def update_plot(self):
        self.view = self.ViewDropDown.currentIndex()
        self.Ly, self.Lx, _ = self.stack[self.currentZ].shape

        if self.view == 0 or self.view == self.ViewDropDown.count() - 1:
            image = self.stack[
                self.currentZ] if self.view == 0 else self.stack_filtered[self.currentZ]
            if self.color == 0:
                self.img.setImage(image, autoLevels=False, lut=None)
                if self.nchan > 1:
                    levels = np.array([
                        self.saturation[0][self.currentZ],
                        self.saturation[1][self.currentZ],
                        self.saturation[2][self.currentZ]
                    ])
                    self.img.setLevels(levels)
                else:
                    self.img.setLevels(self.saturation[0][self.currentZ])
            elif self.color > 0 and self.color < 4:
                if self.nchan > 1:
                    image = image[:, :, self.color - 1]
                self.img.setImage(image, autoLevels=False, lut=self.cmap[self.color])
                if self.nchan > 1:
                    self.img.setLevels(self.saturation[self.color - 1][self.currentZ])
                else:
                    self.img.setLevels(self.saturation[0][self.currentZ])
            elif self.color == 4:
                if self.nchan > 1:
                    image = image.mean(axis=-1)
                self.img.setImage(image, autoLevels=False, lut=None)
                self.img.setLevels(self.saturation[0][self.currentZ])
            elif self.color == 5:
                if self.nchan > 1:
                    image = image.mean(axis=-1)
                self.img.setImage(image, autoLevels=False, lut=self.cmap[0])
                self.img.setLevels(self.saturation[0][self.currentZ])
        else:
            image = np.zeros((self.Ly, self.Lx), np.uint8)
            if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0:
                image = self.flows[self.view - 1][self.currentZ]
            if self.view > 1:
                self.img.setImage(image, autoLevels=False, lut=self.bwr)
            else:
                self.img.setImage(image, autoLevels=False, lut=None)
            self.img.setLevels([0.0, 255.0])

        for r in range(3):
            self.sliders[r].setValue([
                self.saturation[r][self.currentZ][0],
                self.saturation[r][self.currentZ][1]
            ])
        self.win.show()
        self.show()

    def update_layer(self):
        if self.masksOn or self.outlinesOn:
            self.layer.setImage(self.layerz, autoLevels=False)
        self.win.show()
        self.show()

    def add_set(self):
        if len(self.current_point_set) > 0:
            while len(self.strokes) > 0:
                self.remove_stroke(delete_points=False)
            if len(self.current_point_set[0]) > 8:
                color = self.colormap[self.ncells.get(), :3]
                median = self.add_mask(points=self.current_point_set, color=color)
                if median is not None:
                    self.removed_cell = []
                    self.toggle_mask_ops()
                    self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :],
                                                axis=0)
                    self.ncells += 1
                    self.ismanual = np.append(self.ismanual, True)
                    if self.NZ == 1:
                        # only save after each cell if single image
                        io._save_sets_with_check(self)
            else:
                print("GUI_ERROR: cell too small, not drawn")
            self.current_stroke = []
            self.strokes = []
            self.current_point_set = []
            self.update_layer()

    def add_mask(self, points=None, color=(100, 200, 50), dense=True):
        # points is list of strokes
        points_all = np.concatenate(points, axis=0)

        # loop over z values
        median = []
        zdraw = np.unique(points_all[:, 0])
        z = 0
        ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
            0, "int"), np.zeros(0, "int")
        for stroke in points:
            stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
            vr = stroke[:, 1]
            vc = stroke[:, 2]
            # get points inside drawn points
            mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
            pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
                           axis=-1)[:, np.newaxis, :]
            mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
            ar, ac = np.nonzero(mask)
            ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
            # get dense outline
            contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
            pvc, pvr = contours[-2][0][:, 0].T
            vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
            # concatenate all points
            ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
            # if these pixels are overlapping with another cell, reassign them
            ioverlap = self.cellpix[z][ar, ac] > 0
            if (~ioverlap).sum() < 10:
                print("GUI_ERROR: cell < 10 pixels without overlaps, not drawn")
                return None
            elif ioverlap.sum() > 0:
                ar, ac = ar[~ioverlap], ac[~ioverlap]
                # compute outline of new mask
                mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8)
                mask[ar - vr.min() + 2, ac - vc.min() + 2] = 1
                contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                            cv2.CHAIN_APPROX_NONE)
                pvc, pvr = contours[-2][0][:, 0].T
                vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
            ars = np.concatenate((ars, ar), axis=0)
            acs = np.concatenate((acs, ac), axis=0)
            vrs = np.concatenate((vrs, vr), axis=0)
            vcs = np.concatenate((vcs, vc), axis=0)

        self.draw_mask(z, ars, acs, vrs, vcs, color)
        median.append(np.array([np.median(ars), np.median(acs)]))

        self.zdraw.append(zdraw)
        d = datetime.datetime.now()
        self.track_changes.append(
            [d.strftime("%m/%d/%Y, %H:%M:%S"), "added mask", [ar, ac]])
        return median

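    # OpenCV note for the fillPoly/findContours round trip used in add_mask and
    # draw_mask: cv2.findContours returns (image, contours, hierarchy) on OpenCV 3
    # but (contours, hierarchy) on OpenCV 4, so indexing contours[-2] picks the
    # contour list under both conventions. Sketch of the pattern (h, w, pts are
    # placeholders):
    #
    #   mask = np.zeros((h, w), np.uint8)
    #   cv2.fillPoly(mask, [pts], 255)     # pts: (N, 1, 2) int array of (x, y)
    #   res = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    #   outline = res[-2][0]               # (M, 1, 2) boundary points
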
    def draw_mask(self, z, ar, ac, vr, vc, color, idx=None):
        """ draw single mask using outlines and area """
        if idx is None:
            idx = self.ncells + 1
        self.cellpix[z, vr, vc] = idx
        self.cellpix[z, ar, ac] = idx
        self.outpix[z, vr, vc] = idx
        if self.restore and "upsample" in self.restore:
            if self.resize:
                self.cellpix_resize[z, vr, vc] = idx
                self.cellpix_resize[z, ar, ac] = idx
                self.outpix_resize[z, vr, vc] = idx
                self.cellpix_orig[z, (vr / self.ratio).astype(int),
                                  (vc / self.ratio).astype(int)] = idx
                self.cellpix_orig[z, (ar / self.ratio).astype(int),
                                  (ac / self.ratio).astype(int)] = idx
                self.outpix_orig[z, (vr / self.ratio).astype(int),
                                 (vc / self.ratio).astype(int)] = idx
            else:
                self.cellpix_orig[z, vr, vc] = idx
                self.cellpix_orig[z, ar, ac] = idx
                self.outpix_orig[z, vr, vc] = idx

                # get upsampled mask
                vrr = (vr.copy() * self.ratio).astype(int)
                vcr = (vc.copy() * self.ratio).astype(int)
                mask = np.zeros((np.ptp(vrr) + 4, np.ptp(vcr) + 4), np.uint8)
                pts = np.stack((vcr - vcr.min() + 2, vrr - vrr.min() + 2),
                               axis=-1)[:, np.newaxis, :]
                mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
                arr, acr = np.nonzero(mask)
                arr, acr = arr + vrr.min() - 2, acr + vcr.min() - 2
                # get dense outline
                contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                            cv2.CHAIN_APPROX_NONE)
                pvc, pvr = contours[-2][0].squeeze().T
                vrr, vcr = pvr + vrr.min() - 2, pvc + vcr.min() - 2
                # concatenate all points
                arr, acr = np.hstack((np.vstack((vrr, vcr)), np.vstack((arr, acr))))
                self.cellpix_resize[z, vrr, vcr] = idx
                self.cellpix_resize[z, arr, acr] = idx
                self.outpix_resize[z, vrr, vcr] = idx

        if z == self.currentZ:
            self.layerz[ar, ac, :3] = color
            if self.masksOn:
                self.layerz[ar, ac, -1] = self.opacity
            if self.outlinesOn:
                self.layerz[vr, vc] = np.array(self.outcolor)

    def compute_scale(self):
        # get diameter from gui
        diameter = self.segmentation_settings.diameter
        if not diameter:
            diameter = 30

        self.pr = int(diameter)
        self.radii_padding = int(self.pr * 1.25)
        self.radii = np.zeros((self.Ly + self.radii_padding, self.Lx, 4), np.uint8)
        yy, xx = disk([self.Ly + self.radii_padding / 2 - 1, self.pr / 2 + 1],
                      self.pr / 2, self.Ly + self.radii_padding, self.Lx)
        # rgb(150,50,150)
        self.radii[yy, xx, 0] = 150
        self.radii[yy, xx, 1] = 50
        self.radii[yy, xx, 2] = 150
        self.radii[yy, xx, 3] = 255
        self.p0.setYRange(0, self.Ly + self.radii_padding)
        self.p0.setXRange(0, self.Lx)

    def update_scale(self):
        self.compute_scale()
        self.scale.setImage(self.radii, autoLevels=False)
        self.scale.setLevels([0.0, 255.0])
        self.win.show()
        self.show()

    def draw_layer(self):
        if self.resize:
            self.Ly, self.Lx = self.Lyr, self.Lxr
else:
|
| 1571 |
+
self.Ly, self.Lx = self.Ly0, self.Lx0
|
| 1572 |
+
|
| 1573 |
+
if self.masksOn or self.outlinesOn:
|
| 1574 |
+
if self.restore and "upsample" in self.restore:
|
| 1575 |
+
if self.resize:
|
| 1576 |
+
self.cellpix = self.cellpix_resize.copy()
|
| 1577 |
+
self.outpix = self.outpix_resize.copy()
|
| 1578 |
+
else:
|
| 1579 |
+
self.cellpix = self.cellpix_orig.copy()
|
| 1580 |
+
self.outpix = self.outpix_orig.copy()
|
| 1581 |
+
|
| 1582 |
+
self.layerz = np.zeros((self.Ly, self.Lx, 4), np.uint8)
|
| 1583 |
+
if self.masksOn:
|
| 1584 |
+
self.layerz[..., :3] = self.cellcolors[self.cellpix[self.currentZ], :]
|
| 1585 |
+
self.layerz[..., 3] = self.opacity * (self.cellpix[self.currentZ]
|
| 1586 |
+
> 0).astype(np.uint8)
|
| 1587 |
+
if self.selected > 0:
|
| 1588 |
+
self.layerz[self.cellpix[self.currentZ] == self.selected] = np.array(
|
| 1589 |
+
[255, 255, 255, self.opacity])
|
| 1590 |
+
cZ = self.currentZ
|
| 1591 |
+
stroke_z = np.array([s[0][0] for s in self.strokes])
|
| 1592 |
+
inZ = np.nonzero(stroke_z == cZ)[0]
|
| 1593 |
+
if len(inZ) > 0:
|
| 1594 |
+
for i in inZ:
|
| 1595 |
+
stroke = np.array(self.strokes[i])
|
| 1596 |
+
self.layerz[stroke[:, 1], stroke[:,
|
| 1597 |
+
2]] = np.array([255, 0, 255, 100])
|
| 1598 |
+
else:
|
| 1599 |
+
self.layerz[..., 3] = 0
|
| 1600 |
+
|
| 1601 |
+
if self.outlinesOn:
|
| 1602 |
+
self.layerz[self.outpix[self.currentZ] > 0] = np.array(
|
| 1603 |
+
self.outcolor).astype(np.uint8)
|
| 1604 |
+
|
| 1605 |
+
|
| 1606 |
+
def set_normalize_params(self, normalize_params):
|
| 1607 |
+
from cellpose.models import normalize_default
|
| 1608 |
+
if self.restore != "filter":
|
| 1609 |
+
keys = list(normalize_params.keys()).copy()
|
| 1610 |
+
for key in keys:
|
| 1611 |
+
if key != "percentile":
|
| 1612 |
+
normalize_params[key] = normalize_default[key]
|
| 1613 |
+
normalize_params = {**normalize_default, **normalize_params}
|
| 1614 |
+
out = self.check_filter_params(normalize_params["sharpen_radius"],
|
| 1615 |
+
normalize_params["smooth_radius"],
|
| 1616 |
+
normalize_params["tile_norm_blocksize"],
|
| 1617 |
+
normalize_params["tile_norm_smooth3D"],
|
| 1618 |
+
normalize_params["norm3D"],
|
| 1619 |
+
normalize_params["invert"])
|
| 1620 |
+
|
| 1621 |
+
|
| 1622 |
+
def check_filter_params(self, sharpen, smooth, tile_norm, smooth3D, norm3D, invert):
|
| 1623 |
+
tile_norm = 0 if tile_norm < 0 else tile_norm
|
| 1624 |
+
sharpen = 0 if sharpen < 0 else sharpen
|
| 1625 |
+
smooth = 0 if smooth < 0 else smooth
|
| 1626 |
+
smooth3D = 0 if smooth3D < 0 else smooth3D
|
| 1627 |
+
norm3D = bool(norm3D)
|
| 1628 |
+
invert = bool(invert)
|
| 1629 |
+
if tile_norm > self.Ly and tile_norm > self.Lx:
|
| 1630 |
+
print(
|
| 1631 |
+
"GUI_ERROR: tile size (tile_norm) bigger than both image dimensions, disabling"
|
| 1632 |
+
)
|
| 1633 |
+
tile_norm = 0
|
| 1634 |
+
self.filt_edits[0].setText(str(sharpen))
|
| 1635 |
+
self.filt_edits[1].setText(str(smooth))
|
| 1636 |
+
self.filt_edits[2].setText(str(tile_norm))
|
| 1637 |
+
self.filt_edits[3].setText(str(smooth3D))
|
| 1638 |
+
self.norm3D_cb.setChecked(norm3D)
|
| 1639 |
+
return sharpen, smooth, tile_norm, smooth3D, norm3D, invert
|
| 1640 |
+
|
| 1641 |
+
def get_normalize_params(self):
|
| 1642 |
+
percentile = [
|
| 1643 |
+
self.segmentation_settings.low_percentile,
|
| 1644 |
+
self.segmentation_settings.high_percentile,
|
| 1645 |
+
]
|
| 1646 |
+
normalize_params = {"percentile": percentile}
|
| 1647 |
+
norm3D = self.norm3D_cb.isChecked()
|
| 1648 |
+
normalize_params["norm3D"] = norm3D
|
| 1649 |
+
sharpen = float(self.filt_edits[0].text())
|
| 1650 |
+
smooth = float(self.filt_edits[1].text())
|
| 1651 |
+
tile_norm = float(self.filt_edits[2].text())
|
| 1652 |
+
smooth3D = float(self.filt_edits[3].text())
|
| 1653 |
+
invert = False
|
| 1654 |
+
out = self.check_filter_params(sharpen, smooth, tile_norm, smooth3D, norm3D,
|
| 1655 |
+
invert)
|
| 1656 |
+
sharpen, smooth, tile_norm, smooth3D, norm3D, invert = out
|
| 1657 |
+
normalize_params["sharpen_radius"] = sharpen
|
| 1658 |
+
normalize_params["smooth_radius"] = smooth
|
| 1659 |
+
normalize_params["tile_norm_blocksize"] = tile_norm
|
| 1660 |
+
normalize_params["tile_norm_smooth3D"] = smooth3D
|
| 1661 |
+
normalize_params["invert"] = invert
|
| 1662 |
+
|
| 1663 |
+
from cellpose.models import normalize_default
|
| 1664 |
+
normalize_params = {**normalize_default, **normalize_params}
|
| 1665 |
+
|
| 1666 |
+
return normalize_params
|
| 1667 |
+
|
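The dict built by `get_normalize_params` can also be assembled by hand and passed to a model outside the GUI. A minimal sketch of the same merge-with-defaults pattern; the override values below are illustrative only:

from cellpose.models import normalize_default

normalize_params = {
    "percentile": [1.0, 99.0],    # low/high intensity percentiles for the stretch
    "norm3D": False,              # normalize the whole volume vs per-plane
    "sharpen_radius": 0,          # 0 disables sharpening
    "smooth_radius": 0,           # 0 disables smoothing
    "tile_norm_blocksize": 0,     # 0 disables tiled normalization
    "tile_norm_smooth3D": 1,
    "invert": False,
}
# the same merge the GUI performs: explicit settings override the library defaults
normalize_params = {**normalize_default, **normalize_params}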
+    def compute_saturation_if_checked(self):
+        if self.autobtn.isChecked():
+            self.compute_saturation()
+
+    def compute_saturation(self, return_img=False):
+        norm = self.get_normalize_params()
+        print(norm)
+        sharpen, smooth = norm["sharpen_radius"], norm["smooth_radius"]
+        percentile = norm["percentile"]
+        tile_norm = norm["tile_norm_blocksize"]
+        invert = norm["invert"]
+        norm3D = norm["norm3D"]
+        smooth3D = norm["tile_norm_smooth3D"]
+        tile_norm = norm["tile_norm_blocksize"]
+
+        if sharpen > 0 or smooth > 0 or tile_norm > 0:
+            img_norm = self.stack.copy()
+        else:
+            img_norm = self.stack
+
+        if sharpen > 0 or smooth > 0 or tile_norm > 0:
+            self.restore = "filter"
+            print(
+                "GUI_INFO: computing filtered image because sharpen > 0 or tile_norm > 0"
+            )
+            print(
+                "GUI_WARNING: will use memory to create filtered image -- make sure to have RAM for this"
+            )
+            img_norm = self.stack.copy()
+            if sharpen > 0 or smooth > 0:
+                img_norm = smooth_sharpen_img(self.stack, sharpen_radius=sharpen,
+                                              smooth_radius=smooth)
+
+            if tile_norm > 0:
+                img_norm = normalize99_tile(img_norm, blocksize=tile_norm,
+                                            lower=percentile[0], upper=percentile[1],
+                                            smooth3D=smooth3D, norm3D=norm3D)
+            # convert to 0->255
+            img_norm_min = img_norm.min()
+            img_norm_max = img_norm.max()
+            for c in range(img_norm.shape[-1]):
+                if np.ptp(img_norm[..., c]) > 1e-3:
+                    img_norm[..., c] -= img_norm_min
+                    img_norm[..., c] /= (img_norm_max - img_norm_min)
+            img_norm *= 255
+            self.stack_filtered = img_norm
+            self.ViewDropDown.model().item(self.ViewDropDown.count() -
+                                           1).setEnabled(True)
+            self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1)
+        else:
+            img_norm = self.stack if self.restore is None or self.restore == "filter" else self.stack_filtered
+
+        if self.autobtn.isChecked():
+            self.saturation = []
+            for c in range(img_norm.shape[-1]):
+                self.saturation.append([])
+                if np.ptp(img_norm[..., c]) > 1e-3:
+                    if norm3D:
+                        x01 = np.percentile(img_norm[..., c], percentile[0])
+                        x99 = np.percentile(img_norm[..., c], percentile[1])
+                        if invert:
+                            x01i = 255. - x99
+                            x99i = 255. - x01
+                            x01, x99 = x01i, x99i
+                        for n in range(self.NZ):
+                            self.saturation[-1].append([x01, x99])
+                    else:
+                        for z in range(self.NZ):
+                            if self.NZ > 1:
+                                x01 = np.percentile(img_norm[z, :, :, c], percentile[0])
+                                x99 = np.percentile(img_norm[z, :, :, c], percentile[1])
+                            else:
+                                x01 = np.percentile(img_norm[..., c], percentile[0])
+                                x99 = np.percentile(img_norm[..., c], percentile[1])
+                            if invert:
+                                x01i = 255. - x99
+                                x99i = 255. - x01
+                                x01, x99 = x01i, x99i
+                            self.saturation[-1].append([x01, x99])
+                else:
+                    for n in range(self.NZ):
+                        self.saturation[-1].append([0, 255.])
+            print(self.saturation[2][self.currentZ])
+
+            if img_norm.shape[-1] == 1:
+                self.saturation.append(self.saturation[0])
+                self.saturation.append(self.saturation[0])
+
+        # self.autobtn.setChecked(True)
+        self.update_plot()
+
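The display-level logic above reduces to a per-channel percentile stretch with an optional inverted window. A minimal numpy sketch on random data:

import numpy as np

img = np.random.rand(512, 512) * 255
low, high = 1.0, 99.0                 # the GUI's low/high percentiles
x01 = np.percentile(img, low)
x99 = np.percentile(img, high)
invert = False
if invert:                            # an inverted view flips the intensity window
    x01, x99 = 255. - x99, 255. - x01
levels = [x01, x99]                   # what the GUI hands to pyqtgraph setLevels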
+    def get_model_path(self, custom=False):
+        if custom:
+            self.current_model = self.ModelChooseC.currentText()
+            self.current_model_path = os.fspath(
+                models.MODEL_DIR.joinpath(self.current_model))
+        else:
+            self.current_model = "cpsam"
+            self.current_model_path = models.model_path(self.current_model)
+
+    def initialize_model(self, model_name=None, custom=False):
+        if model_name is None or custom:
+            self.get_model_path(custom=custom)
+            if not os.path.exists(self.current_model_path):
+                raise ValueError("need to specify model (use dropdown)")
+
+        if model_name is None or not isinstance(model_name, str):
+            self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
+                                              pretrained_model=self.current_model_path)
+        else:
+            self.current_model = model_name
+            self.current_model_path = os.fspath(
+                models.MODEL_DIR.joinpath(self.current_model))
+
+            self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
+                                              pretrained_model=self.current_model)
+
+    def add_model(self):
+        io._add_model(self)
+        return
+
+    def remove_model(self):
+        io._remove_model(self)
+        return
+
+    def new_model(self):
+        if self.NZ != 1:
+            print("ERROR: cannot train model on 3D data")
+            return
+
+        # train model
+        image_names = self.get_files()[0]
+        self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set(
+            image_names)
+        TW = guiparts.TrainWindow(self, models.MODEL_NAMES)
+        train = TW.exec_()
+        if train:
+            self.logger.info(
+                f"training with {[os.path.split(f)[1] for f in self.train_files]}")
+            self.train_model(restore=restore, normalize_params=normalize_params)
+        else:
+            print("GUI_INFO: training cancelled")
+
+    def train_model(self, restore=None, normalize_params=None):
+        from cellpose.models import normalize_default
+        if normalize_params is None:
+            normalize_params = copy.deepcopy(normalize_default)
+        model_type = models.MODEL_NAMES[self.training_params["model_index"]]
+        self.logger.info(f"training new model starting at model {model_type}")
+        self.current_model = model_type
+
+        self.model = models.CellposeModel(gpu=self.useGPU.isChecked(),
+                                          model_type=model_type)
+        save_path = os.path.dirname(self.filename)
+
+        print("GUI_INFO: name of new model: " + self.training_params["model_name"])
+        self.new_model_path, train_losses = train.train_seg(
+            self.model.net, train_data=self.train_data, train_labels=self.train_labels,
+            normalize=normalize_params, min_train_masks=0,
+            save_path=save_path, nimg_per_epoch=max(2, len(self.train_data)),
+            learning_rate=self.training_params["learning_rate"],
+            weight_decay=self.training_params["weight_decay"],
+            n_epochs=self.training_params["n_epochs"],
+            model_name=self.training_params["model_name"])[:2]
+        # save train losses
+        np.save(str(self.new_model_path) + "_train_losses.npy", train_losses)
+        # run model on next image
+        io._add_model(self, self.new_model_path)
+        diam_labels = self.model.net.diam_labels.item()  #.copy()
+        self.new_model_ind = len(self.model_strings)
+        self.autorun = True
+        self.clear_all()
+        self.restore = restore
+        self.set_normalize_params(normalize_params)
+        self.get_next_image(load_seg=False)
+
+        self.compute_segmentation(custom=True)
+        self.logger.info(
+            f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!"
+        )
+
+    def compute_cprob(self):
+        if self.recompute_masks:
+            flow_threshold = self.segmentation_settings.flow_threshold
+            cellprob_threshold = self.segmentation_settings.cellprob_threshold
+            niter = self.segmentation_settings.niter
+            min_size = int(self.min_size.text()) if not isinstance(
+                self.min_size, int) else self.min_size
+
+            self.logger.info(
+                "computing masks with cell prob=%0.3f, flow error threshold=%0.3f" %
+                (cellprob_threshold, flow_threshold))
+
+            try:
+                dP = self.flows[2].squeeze()
+                cellprob = self.flows[3].squeeze()
+            except IndexError:
+                self.logger.error("Flows don't exist, try running model again.")
+                return
+
+            maski = dynamics.resize_and_compute_masks(
+                dP=dP,
+                cellprob=cellprob,
+                niter=niter,
+                do_3D=self.load_3D,
+                min_size=min_size,
+                # max_size_fraction=min_size_fraction, # Leave as default
+                cellprob_threshold=cellprob_threshold,
+                flow_threshold=flow_threshold)
+
+            self.masksOn = True
+            if not self.OCheckBox.isChecked():
+                self.MCheckBox.setChecked(True)
+            if maski.ndim < 3:
+                maski = maski[np.newaxis, ...]
+            self.logger.info("%d cells found" % (len(np.unique(maski)[1:])))
+            io._masks_to_gui(self, maski, outlines=None)
+            self.show()
+
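`compute_cprob` re-thresholds the cached network outputs without re-running the network. A sketch of the same call pattern with the keyword arguments used above; the stand-in arrays replace the flows a real `eval` call would have cached:

import numpy as np
from cellpose import dynamics

dP = np.zeros((2, 256, 256), "float32")       # stand-in XY flow field
cellprob = np.zeros((256, 256), "float32")    # stand-in cell probability map

masks = dynamics.resize_and_compute_masks(
    dP=dP, cellprob=cellprob, niter=200, do_3D=False,
    min_size=15, cellprob_threshold=0.0, flow_threshold=0.4)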
+    def compute_segmentation(self, custom=False, model_name=None, load_model=True):
+        self.progress.setValue(0)
+        try:
+            tic = time.time()
+            self.clear_all()
+            self.flows = [[], [], []]
+            if load_model:
+                self.initialize_model(model_name=model_name, custom=custom)
+            self.progress.setValue(10)
+            do_3D = self.load_3D
+            stitch_threshold = float(self.stitch_threshold.text()) if not isinstance(
+                self.stitch_threshold, float) else self.stitch_threshold
+            anisotropy = float(self.anisotropy.text()) if not isinstance(
+                self.anisotropy, float) else self.anisotropy
+            flow3D_smooth = float(self.flow3D_smooth.text()) if not isinstance(
+                self.flow3D_smooth, float) else self.flow3D_smooth
+            min_size = int(self.min_size.text()) if not isinstance(
+                self.min_size, int) else self.min_size
+
+            do_3D = False if stitch_threshold > 0. else do_3D
+
+            if self.restore == "filter":
+                data = self.stack_filtered.copy().squeeze()
+            else:
+                data = self.stack.copy().squeeze()
+
+            flow_threshold = self.segmentation_settings.flow_threshold
+            cellprob_threshold = self.segmentation_settings.cellprob_threshold
+            diameter = self.segmentation_settings.diameter
+            niter = self.segmentation_settings.niter
+
+            normalize_params = self.get_normalize_params()
+            print(normalize_params)
+            try:
+                masks, flows = self.model.eval(
+                    data,
+                    diameter=diameter,
+                    cellprob_threshold=cellprob_threshold,
+                    flow_threshold=flow_threshold, do_3D=do_3D, niter=niter,
+                    normalize=normalize_params, stitch_threshold=stitch_threshold,
+                    anisotropy=anisotropy, flow3D_smooth=flow3D_smooth,
+                    min_size=min_size, channel_axis=-1,
+                    progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2]
+            except Exception as e:
+                print("NET ERROR: %s" % e)
+                self.progress.setValue(0)
+                return
+
+            self.progress.setValue(75)
+
+            # convert flows to uint8 and resize to original image size
+            flows_new = []
+            flows_new.append(flows[0].copy())  # RGB flow
+            flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) *
+                              255).astype("uint8"))  # cellprob
+            flows_new.append(flows[1].copy())  # XY flows
+            flows_new.append(flows[2].copy())  # original cellprob
+
+            if self.load_3D:
+                if stitch_threshold == 0.:
+                    flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8"))
+                else:
+                    flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8"))
+
+            if not self.load_3D:
+                if self.restore and "upsample" in self.restore:
+                    self.Ly, self.Lx = self.Lyr, self.Lxr
+
+                if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx):
+                    self.flows = []
+                    for j in range(len(flows_new)):
+                        self.flows.append(
+                            resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx,
+                                         interpolation=cv2.INTER_NEAREST))
+                else:
+                    self.flows = flows_new
+            else:
+                self.flows = []
+                Lz, Ly, Lx = self.NZ, self.Ly, self.Lx
+                Lz0, Ly0, Lx0 = flows_new[0].shape[:3]
+                print("GUI_INFO: resizing flows to original image size")
+                for j in range(len(flows_new)):
+                    flow0 = flows_new[j]
+                    if Ly0 != Ly:
+                        flow0 = resize_image(flow0, Ly=Ly, Lx=Lx,
+                                             no_channels=flow0.ndim == 3,
+                                             interpolation=cv2.INTER_NEAREST)
+                    if Lz0 != Lz:
+                        flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1),
+                                                         Ly=Lz, Lx=Lx,
+                                                         no_channels=flow0.ndim == 3,
+                                                         interpolation=cv2.INTER_NEAREST), 0, 1)
+                    self.flows.append(flow0)
+
+            # add first axis
+            if self.NZ == 1:
+                masks = masks[np.newaxis, ...]
+                self.flows = [
+                    self.flows[n][np.newaxis, ...] for n in range(len(self.flows))
+                ]
+
+            self.logger.info("%d cells found with model in %0.3f sec" %
+                             (len(np.unique(masks)[1:]), time.time() - tic))
+            self.progress.setValue(80)
+            z = 0
+
+            io._masks_to_gui(self, masks, outlines=None)
+            self.masksOn = True
+            self.MCheckBox.setChecked(True)
+            self.progress.setValue(100)
+            if self.restore != "filter" and self.restore is not None and self.autobtn.isChecked():
+                self.compute_saturation()
+            if not do_3D and not stitch_threshold > 0:
+                self.recompute_masks = True
+            else:
+                self.recompute_masks = False
+        except Exception as e:
+            print("ERROR: %s" % e)
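For reference, the GUI's segmentation call above corresponds roughly to this headless usage. A sketch only: the model name, example image path, and parameter values are illustrative, and the slice mirrors the `[:2]` used in the diff:

from cellpose import models
from cellpose.io import imread

model = models.CellposeModel(gpu=True, pretrained_model="cpsam")
img = imread("example_imgs/1977_Well_F-5_Field_1.png")
masks, flows = model.eval(
    img, diameter=30, flow_threshold=0.4, cellprob_threshold=0.0,
    channel_axis=-1)[:2]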
models/seg_post_model/cellpose/gui/gui3d.py
ADDED
@@ -0,0 +1,667 @@
+"""
+Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
+"""
+
+import sys, pathlib, warnings
+
+from qtpy import QtGui, QtCore
+from qtpy.QtWidgets import QApplication, QScrollBar, QCheckBox, QLabel, QLineEdit
+import pyqtgraph as pg
+
+import numpy as np
+from scipy.stats import mode
+import cv2
+
+from . import guiparts, io
+from ..utils import download_url_to_file, masks_to_outlines
+from .gui import MainW
+
+try:
+    import matplotlib.pyplot as plt
+    MATPLOTLIB = True
+except:
+    MATPLOTLIB = False
+
+
+def avg3d(C):
+    """ smooth value of c across nearby points
+        (c is center of grid directly below point)
+        b -- a -- b
+        a -- c -- a
+        b -- a -- b
+    """
+    Ly, Lx = C.shape
+    # pad T by 2
+    T = np.zeros((Ly + 2, Lx + 2), "float32")
+    M = np.zeros((Ly, Lx), "float32")
+    T[1:-1, 1:-1] = C.copy()
+    y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
+                       indexing="ij")
+    y += 1
+    x += 1
+    a = 1. / 2  #/(z**2 + 1)**0.5
+    b = 1. / (1 + 2**0.5)  #(z**2 + 2)**0.5
+    c = 1.
+    M = (b * T[y - 1, x - 1] + a * T[y - 1, x] + b * T[y - 1, x + 1] + a * T[y, x - 1] +
+         c * T[y, x] + a * T[y, x + 1] + b * T[y + 1, x - 1] + a * T[y + 1, x] +
+         b * T[y + 1, x + 1])
+    M /= 4 * a + 4 * b + c
+    return M
+
+
+def interpZ(mask, zdraw):
+    """ find nearby planes and average their values using grid of points
+        zfill is in ascending order
+    """
+    ifill = np.ones(mask.shape[0], "bool")
+    zall = np.arange(0, mask.shape[0], 1, int)
+    ifill[zdraw] = False
+    zfill = zall[ifill]
+    zlower = zdraw[np.searchsorted(zdraw, zfill, side="left") - 1]
+    zupper = zdraw[np.searchsorted(zdraw, zfill, side="right")]
+    for k, z in enumerate(zfill):
+        Z = zupper[k] - zlower[k]
+        zl = (z - zlower[k]) / Z
+        plower = avg3d(mask[zlower[k]]) * (1 - zl)
+        pupper = avg3d(mask[zupper[k]]) * zl
+        mask[z] = (plower + pupper) > 0.33
+    return mask, zfill
+
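A minimal sketch of what `interpZ` does: given a mask drawn on planes 0 and 4 of a small stack, the in-between planes are filled by distance-weighted averaging of the two drawn planes. The toy shapes are made up, and the sketch assumes `interpZ`/`avg3d` above are in scope:

import numpy as np

mask = np.zeros((5, 20, 20), "float32")
mask[0, 5:12, 5:12] = 1          # stroke drawn on plane 0
mask[4, 8:16, 8:16] = 1          # stroke drawn on plane 4
zdraw = np.array([0, 4])

filled, zfill = interpZ(mask, zdraw)
print(zfill)                     # -> [1 2 3], the planes that were interpolated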
+def run(image=None):
+    from ..io import logger_setup
+    logger, log_file = logger_setup()
+    # Always start by initializing Qt (only once per application)
+    warnings.filterwarnings("ignore")
+    app = QApplication(sys.argv)
+    icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png")
+    guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png")
+    style_path = pathlib.Path.home().joinpath(".cellpose", "style_choice.npy")
+    if not icon_path.is_file():
+        cp_dir = pathlib.Path.home().joinpath(".cellpose")
+        cp_dir.mkdir(exist_ok=True)
+        print("downloading logo")
+        download_url_to_file(
+            "https://www.cellpose.org/static/images/cellpose_transparent.png",
+            icon_path, progress=True)
+    if not guip_path.is_file():
+        print("downloading help window image")
+        download_url_to_file("https://www.cellpose.org/static/images/cellpose_gui.png",
+                             guip_path, progress=True)
+    icon_path = str(icon_path.resolve())
+    app_icon = QtGui.QIcon()
+    app_icon.addFile(icon_path, QtCore.QSize(16, 16))
+    app_icon.addFile(icon_path, QtCore.QSize(24, 24))
+    app_icon.addFile(icon_path, QtCore.QSize(32, 32))
+    app_icon.addFile(icon_path, QtCore.QSize(48, 48))
+    app_icon.addFile(icon_path, QtCore.QSize(64, 64))
+    app_icon.addFile(icon_path, QtCore.QSize(256, 256))
+    app.setWindowIcon(app_icon)
+    app.setStyle("Fusion")
+    app.setPalette(guiparts.DarkPalette())
+    MainW_3d(image=image, logger=logger)
+    ret = app.exec_()
+    sys.exit(ret)
+
+
+class MainW_3d(MainW):
+
+    def __init__(self, image=None, logger=None):
+        # MainW init
+        MainW.__init__(self, image=image, logger=logger)
+
+        # add gradZ view
+        self.ViewDropDown.insertItem(3, "gradZ")
+
+        # turn off single stroke
+        self.SCheckBox.setChecked(False)
+
+        ### add orthoviews and z-bar
+        # ortho crosshair lines
+        self.vLine = pg.InfiniteLine(angle=90, movable=False)
+        self.hLine = pg.InfiniteLine(angle=0, movable=False)
+        self.vLineOrtho = [
+            pg.InfiniteLine(angle=90, movable=False),
+            pg.InfiniteLine(angle=90, movable=False)
+        ]
+        self.hLineOrtho = [
+            pg.InfiniteLine(angle=0, movable=False),
+            pg.InfiniteLine(angle=0, movable=False)
+        ]
+        self.make_orthoviews()
+
+        # z scrollbar underneath
+        self.scroll = QScrollBar(QtCore.Qt.Horizontal)
+        self.scroll.setMaximum(10)
+        self.scroll.valueChanged.connect(self.move_in_Z)
+        self.lmain.addWidget(self.scroll, 40, 9, 1, 30)
+
+        b = 22
+
+        label = QLabel("stitch\nthreshold:")
+        label.setToolTip(
+            "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
+        )
+        label.setFont(self.medfont)
+        self.segBoxG.addWidget(label, b, 0, 1, 4)
+        self.stitch_threshold = QLineEdit()
+        self.stitch_threshold.setText("0.0")
+        self.stitch_threshold.setFixedWidth(30)
+        self.stitch_threshold.setFont(self.medfont)
+        self.stitch_threshold.setToolTip(
+            "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)"
+        )
+        self.segBoxG.addWidget(self.stitch_threshold, b, 3, 1, 1)
+
+        label = QLabel("flow3D\nsmooth:")
+        label.setToolTip(
+            "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)"
+        )
+        label.setFont(self.medfont)
+        self.segBoxG.addWidget(label, b, 4, 1, 3)
+        self.flow3D_smooth = QLineEdit()
+        self.flow3D_smooth.setText("0.0")
+        self.flow3D_smooth.setFixedWidth(30)
+        self.flow3D_smooth.setFont(self.medfont)
+        self.flow3D_smooth.setToolTip(
+            "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)"
+        )
+        self.segBoxG.addWidget(self.flow3D_smooth, b, 7, 1, 1)
+
+        b += 1
+        label = QLabel("anisotropy:")
+        label.setToolTip(
+            "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set to 2.0 if Z is sampled half as densely as X or Y (see docs for details)"
+        )
+        label.setFont(self.medfont)
+        self.segBoxG.addWidget(label, b, 0, 1, 3)
+        self.anisotropy = QLineEdit()
+        self.anisotropy.setText("1.0")
+        self.anisotropy.setFixedWidth(30)
+        self.anisotropy.setFont(self.medfont)
+        self.anisotropy.setToolTip(
+            "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set to 2.0 if Z is sampled half as densely as X or Y (see docs for details)"
+        )
+        self.segBoxG.addWidget(self.anisotropy, b, 3, 1, 1)
+
+        b += 1
+        label = QLabel("min\nsize:")
+        label.setToolTip(
+            "all masks less than this size in pixels (volume) will be removed"
+        )
+        label.setFont(self.medfont)
+        self.segBoxG.addWidget(label, b, 0, 1, 4)
+        self.min_size = QLineEdit()
+        self.min_size.setText("15")
+        self.min_size.setFixedWidth(50)
+        self.min_size.setFont(self.medfont)
+        self.min_size.setToolTip(
+            "all masks less than this size in pixels (volume) will be removed"
+        )
+        self.segBoxG.addWidget(self.min_size, b, 3, 1, 1)
+
+        b += 1
+        self.orthobtn = QCheckBox("ortho")
+        self.orthobtn.setToolTip("activate orthoviews with 3D image")
+        self.orthobtn.setFont(self.medfont)
+        self.orthobtn.setChecked(False)
+        self.l0.addWidget(self.orthobtn, b, 0, 1, 2)
+        self.orthobtn.toggled.connect(self.toggle_ortho)
+
+        label = QLabel("dz:")
+        label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+        label.setFont(self.medfont)
+        self.l0.addWidget(label, b, 2, 1, 1)
+        self.dz = 10
+        self.dzedit = QLineEdit()
+        self.dzedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+        self.dzedit.setText(str(self.dz))
+        self.dzedit.returnPressed.connect(self.update_ortho)
+        self.dzedit.setFixedWidth(40)
+        self.dzedit.setFont(self.medfont)
+        self.l0.addWidget(self.dzedit, b, 3, 1, 2)
+
+        label = QLabel("z-aspect:")
+        label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+        label.setFont(self.medfont)
+        self.l0.addWidget(label, b, 5, 1, 2)
+        self.zaspect = 1.0
+        self.zaspectedit = QLineEdit()
+        self.zaspectedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+        self.zaspectedit.setText(str(self.zaspect))
+        self.zaspectedit.returnPressed.connect(self.update_ortho)
+        self.zaspectedit.setFixedWidth(40)
+        self.zaspectedit.setFont(self.medfont)
+        self.l0.addWidget(self.zaspectedit, b, 7, 1, 2)
+
+        b += 1
+        # add z position underneath
+        self.currentZ = 0
+        label = QLabel("Z:")
+        label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+        self.l0.addWidget(label, b, 5, 1, 2)
+        self.zpos = QLineEdit()
+        self.zpos.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
+        self.zpos.setText(str(self.currentZ))
+        self.zpos.returnPressed.connect(self.update_ztext)
+        self.zpos.setFixedWidth(40)
+        self.zpos.setFont(self.medfont)
+        self.l0.addWidget(self.zpos, b, 7, 1, 2)
+
+        # if called with image, load it
+        if image is not None:
+            self.filename = image
+            io._load_image(self, self.filename, load_3D=True)
+
+        self.load_3D = True
+
+    def add_mask(self, points=None, color=(100, 200, 50), dense=True):
+        # points is list of strokes
+
+        points_all = np.concatenate(points, axis=0)
+
+        # loop over z values
+        median = []
+        zdraw = np.unique(points_all[:, 0])
+        zrange = np.arange(zdraw.min(), zdraw.max() + 1, 1, int)
+        zmin = zdraw.min()
+        pix = np.zeros((2, 0), "uint16")
+        mall = np.zeros((len(zrange), self.Ly, self.Lx), "bool")
+        k = 0
+        for z in zdraw:
+            ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros(
+                0, "int"), np.zeros(0, "int")
+            for stroke in points:
+                stroke = np.concatenate(stroke, axis=0).reshape(-1, 4)
+                iz = stroke[:, 0] == z
+                vr = stroke[iz, 1]
+                vc = stroke[iz, 2]
+                if iz.sum() > 0:
+                    # get points inside drawn points
+                    mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8")
+                    pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2),
+                                   axis=-1)[:, np.newaxis, :]
+                    mask = cv2.fillPoly(mask, [pts], (255, 0, 0))
+                    ar, ac = np.nonzero(mask)
+                    ar, ac = ar + vr.min() - 2, ac + vc.min() - 2
+                    # get dense outline
+                    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+                                                cv2.CHAIN_APPROX_NONE)
+                    pvc, pvr = contours[-2][0].squeeze().T
+                    vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2
+                    # concatenate all points
+                    ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac))))
+                    # if these pixels are overlapping with another cell, reassign them
+                    ioverlap = self.cellpix[z][ar, ac] > 0
+                    if (~ioverlap).sum() < 8:
+                        print("ERROR: cell too small without overlaps, not drawn")
+                        return None
+                    elif ioverlap.sum() > 0:
+                        ar, ac = ar[~ioverlap], ac[~ioverlap]
+                        # compute outline of new mask
+                        mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8")
+                        mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1
+                        contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
+                                                    cv2.CHAIN_APPROX_NONE)
+                        pvc, pvr = contours[-2][0].squeeze().T
+                        vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2
+                    ars = np.concatenate((ars, ar), axis=0)
+                    acs = np.concatenate((acs, ac), axis=0)
+                    vrs = np.concatenate((vrs, vr), axis=0)
+                    vcs = np.concatenate((vcs, vc), axis=0)
+            self.draw_mask(z, ars, acs, vrs, vcs, color)
+
+            median.append(np.array([np.median(ars), np.median(acs)]))
+            mall[z - zmin, ars, acs] = True
+            pix = np.append(pix, np.vstack((ars, acs)), axis=-1)
+
+        mall = mall[:, pix[0].min():pix[0].max() + 1,
+                    pix[1].min():pix[1].max() + 1].astype("float32")
+        ymin, xmin = pix[0].min(), pix[1].min()
+        if len(zdraw) > 1:
+            mall, zfill = interpZ(mall, zdraw - zmin)
+            for z in zfill:
+                mask = mall[z].copy()
+                ar, ac = np.nonzero(mask)
+                ioverlap = self.cellpix[z + zmin][ar + ymin, ac + xmin] > 0
+                if (~ioverlap).sum() < 5:
+                    print("WARNING: stroke on plane %d not included due to overlaps" %
+                          z)
+                elif ioverlap.sum() > 0:
+                    mask[ar[ioverlap], ac[ioverlap]] = 0
+                    ar, ac = ar[~ioverlap], ac[~ioverlap]
+                # compute outline of mask
+                outlines = masks_to_outlines(mask)
+                vr, vc = np.nonzero(outlines)
+                vr, vc = vr + ymin, vc + xmin
+                ar, ac = ar + ymin, ac + xmin
+                self.draw_mask(z + zmin, ar, ac, vr, vc, color)
+
+        self.zdraw.append(zdraw)
+
+        return median
+
+    def move_in_Z(self):
+        if self.loaded:
+            self.currentZ = min(self.NZ, max(0, int(self.scroll.value())))
+            self.zpos.setText(str(self.currentZ))
+            self.update_plot()
+            self.draw_layer()
+            self.update_layer()
+
+    def make_orthoviews(self):
+        self.pOrtho, self.imgOrtho, self.layerOrtho = [], [], []
+        for j in range(2):
+            self.pOrtho.append(
+                pg.ViewBox(lockAspect=True, name=f"plotOrtho{j}",
+                           border=[100, 100, 100], invertY=True, enableMouse=False))
+            self.pOrtho[j].setMenuEnabled(False)
+
+            self.imgOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self))
+            self.imgOrtho[j].autoDownsample = False
+
+            self.layerOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self))
+            self.layerOrtho[j].setLevels([0., 255.])
+
+            #self.pOrtho[j].scene().contextMenuItem = self.pOrtho[j]
+            self.pOrtho[j].addItem(self.imgOrtho[j])
+            self.pOrtho[j].addItem(self.layerOrtho[j])
+            self.pOrtho[j].addItem(self.vLineOrtho[j], ignoreBounds=False)
+            self.pOrtho[j].addItem(self.hLineOrtho[j], ignoreBounds=False)
+
+        self.pOrtho[0].linkView(self.pOrtho[0].YAxis, self.p0)
+        self.pOrtho[1].linkView(self.pOrtho[1].XAxis, self.p0)
+
+    def add_orthoviews(self):
+        self.yortho = self.Ly // 2
+        self.xortho = self.Lx // 2
+        if self.NZ > 1:
+            self.update_ortho()
+
+        self.win.addItem(self.pOrtho[0], 0, 1, rowspan=1, colspan=1)
+        self.win.addItem(self.pOrtho[1], 1, 0, rowspan=1, colspan=1)
+
+        qGraphicsGridLayout = self.win.ci.layout
+        qGraphicsGridLayout.setColumnStretchFactor(0, 2)
+        qGraphicsGridLayout.setColumnStretchFactor(1, 1)
+        qGraphicsGridLayout.setRowStretchFactor(0, 2)
+        qGraphicsGridLayout.setRowStretchFactor(1, 1)
+
+        self.pOrtho[0].setYRange(0, self.Lx)
+        self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
+        self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
+        self.pOrtho[1].setXRange(0, self.Ly)
+
+        self.p0.addItem(self.vLine, ignoreBounds=False)
+        self.p0.addItem(self.hLine, ignoreBounds=False)
+        self.p0.setYRange(0, self.Lx)
+        self.p0.setXRange(0, self.Ly)
+
+        self.win.show()
+        self.show()
+
+    def remove_orthoviews(self):
+        self.win.removeItem(self.pOrtho[0])
+        self.win.removeItem(self.pOrtho[1])
+        self.p0.removeItem(self.vLine)
+        self.p0.removeItem(self.hLine)
+        self.win.show()
+        self.show()
+
+    def update_crosshairs(self):
+        self.yortho = min(self.Ly - 1, max(0, int(self.yortho)))
+        self.xortho = min(self.Lx - 1, max(0, int(self.xortho)))
+        self.vLine.setPos(self.xortho)
+        self.hLine.setPos(self.yortho)
+        self.vLineOrtho[1].setPos(self.xortho)
+        self.hLineOrtho[1].setPos(self.zc)
+        self.vLineOrtho[0].setPos(self.zc)
+        self.hLineOrtho[0].setPos(self.yortho)
+
+    def update_ortho(self):
+        if self.NZ > 1 and self.orthobtn.isChecked():
+            dzcurrent = self.dz
+            self.dz = min(100, max(3, int(self.dzedit.text())))
+            self.zaspect = max(0.01, min(100., float(self.zaspectedit.text())))
+            self.dzedit.setText(str(self.dz))
+            self.zaspectedit.setText(str(self.zaspect))
+            if self.dz != dzcurrent:
+                self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
+                self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
+            dztot = min(self.NZ, self.dz * 2)
+            y = self.yortho
+            x = self.xortho
+            z = self.currentZ
+            if dztot == self.NZ:
+                zmin, zmax = 0, self.NZ
+            else:
+                if z - self.dz < 0:
+                    zmin = 0
+                    zmax = zmin + self.dz * 2
+                elif z + self.dz >= self.NZ:
+                    zmax = self.NZ
+                    zmin = zmax - self.dz * 2
+                else:
+                    zmin, zmax = z - self.dz, z + self.dz
+            self.zc = z - zmin
+            self.update_crosshairs()
+            if self.view == 0 or self.view == 4:
+                for j in range(2):
+                    if j == 0:
+                        if self.view == 0:
+                            image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy()
+                        else:
+                            image = self.stack_filtered[zmin:zmax, :,
+                                                        x].transpose(1, 0, 2).copy()
+                    else:
+                        image = self.stack[
+                            zmin:zmax,
+                            y, :].copy() if self.view == 0 else self.stack_filtered[
+                                zmin:zmax, y, :].copy()
+                    if self.nchan == 1:
+                        # show single channel
+                        image = image[..., 0]
+                    if self.color == 0:
+                        self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
+                        if self.nchan > 1:
+                            levels = np.array([
+                                self.saturation[0][self.currentZ],
+                                self.saturation[1][self.currentZ],
+                                self.saturation[2][self.currentZ]
+                            ])
+                            self.imgOrtho[j].setLevels(levels)
+                        else:
+                            self.imgOrtho[j].setLevels(
+                                self.saturation[0][self.currentZ])
+                    elif self.color > 0 and self.color < 4:
+                        if self.nchan > 1:
+                            image = image[..., self.color - 1]
+                        self.imgOrtho[j].setImage(image, autoLevels=False,
+                                                  lut=self.cmap[self.color])
+                        if self.nchan > 1:
+                            self.imgOrtho[j].setLevels(
+                                self.saturation[self.color - 1][self.currentZ])
+                        else:
+                            self.imgOrtho[j].setLevels(
+                                self.saturation[0][self.currentZ])
+                    elif self.color == 4:
+                        if image.ndim > 2:
+                            image = image.astype("float32").mean(axis=2).astype("uint8")
+                        self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
+                        self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
+                    elif self.color == 5:
+                        if image.ndim > 2:
+                            image = image.astype("float32").mean(axis=2).astype("uint8")
+                        self.imgOrtho[j].setImage(image, autoLevels=False,
+                                                  lut=self.cmap[0])
+                        self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
+                self.pOrtho[0].setAspectLocked(lock=True, ratio=self.zaspect)
+                self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect)
+
+            else:
+                image = np.zeros((10, 10), "uint8")
+                self.imgOrtho[0].setImage(image, autoLevels=False, lut=None)
+                self.imgOrtho[0].setLevels([0.0, 255.0])
+                self.imgOrtho[1].setImage(image, autoLevels=False, lut=None)
+                self.imgOrtho[1].setLevels([0.0, 255.0])
+
+            zrange = zmax - zmin
+            self.layer_ortho = [
+                np.zeros((self.Ly, zrange, 4), "uint8"),
+                np.zeros((zrange, self.Lx, 4), "uint8")
+            ]
+            if self.masksOn:
+                for j in range(2):
+                    if j == 0:
+                        cp = self.cellpix[zmin:zmax, :, x].T
+                    else:
+                        cp = self.cellpix[zmin:zmax, y]
+                    self.layer_ortho[j][..., :3] = self.cellcolors[cp, :]
+                    self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8")
+                    if self.selected > 0:
+                        self.layer_ortho[j][cp == self.selected] = np.array(
+                            [255, 255, 255, self.opacity])
+
+            if self.outlinesOn:
+                for j in range(2):
+                    if j == 0:
+                        op = self.outpix[zmin:zmax, :, x].T
+                    else:
+                        op = self.outpix[zmin:zmax, y]
+                    self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8")
+
+            for j in range(2):
+                self.layerOrtho[j].setImage(self.layer_ortho[j])
+        self.win.show()
+        self.show()
+
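The zmin/zmax logic in `update_ortho` above is a clamped sliding window of 2*dz planes around the current plane. A minimal sketch of the same arithmetic; `z_window` is a hypothetical helper name, not part of the diff:

def z_window(z, dz, NZ):
    # returns the [zmin, zmax) slab shown in the ortho views
    if min(NZ, dz * 2) == NZ:     # whole stack fits in the window
        return 0, NZ
    if z - dz < 0:                # clamped at the bottom of the stack
        return 0, dz * 2
    if z + dz >= NZ:              # clamped at the top of the stack
        return NZ - dz * 2, NZ
    return z - dz, z + dz         # centred window

print(z_window(z=5, dz=10, NZ=60))    # -> (0, 20)
print(z_window(z=30, dz=10, NZ=60))   # -> (20, 40)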
+
def toggle_ortho(self):
|
| 539 |
+
if self.orthobtn.isChecked():
|
| 540 |
+
self.add_orthoviews()
|
| 541 |
+
else:
|
| 542 |
+
self.remove_orthoviews()
|
| 543 |
+
|
| 544 |
+
def plot_clicked(self, event):
|
| 545 |
+
if event.button()==QtCore.Qt.LeftButton \
|
| 546 |
+
and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\
|
| 547 |
+
and not self.removing_region:
|
| 548 |
+
if event.double():
|
| 549 |
+
try:
|
| 550 |
+
self.p0.setYRange(0, self.Ly + self.pr)
|
| 551 |
+
except:
|
| 552 |
+
self.p0.setYRange(0, self.Ly)
|
| 553 |
+
self.p0.setXRange(0, self.Lx)
|
| 554 |
+
elif self.loaded and not self.in_stroke:
|
| 555 |
+
if self.orthobtn.isChecked():
|
| 556 |
+
items = self.win.scene().items(event.scenePos())
|
| 557 |
+
for x in items:
|
| 558 |
+
if x == self.p0:
|
| 559 |
+
pos = self.p0.mapSceneToView(event.scenePos())
|
| 560 |
+
x = int(pos.x())
|
| 561 |
+
y = int(pos.y())
|
| 562 |
+
if y >= 0 and y < self.Ly and x >= 0 and x < self.Lx:
|
| 563 |
+
self.yortho = y
|
| 564 |
+
self.xortho = x
|
| 565 |
+
self.update_ortho()
|
| 566 |
+
|
| 567 |
+
def update_plot(self):
|
| 568 |
+
super().update_plot()
|
| 569 |
+
if self.NZ > 1 and self.orthobtn.isChecked():
|
| 570 |
+
self.update_ortho()
|
| 571 |
+
self.win.show()
|
| 572 |
+
self.show()
|
| 573 |
+
|
| 574 |
+
def keyPressEvent(self, event):
|
| 575 |
+
if self.loaded:
|
| 576 |
+
if not (event.modifiers() &
|
| 577 |
+
(QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier |
|
| 578 |
+
QtCore.Qt.AltModifier) or self.in_stroke):
|
| 579 |
+
updated = False
|
| 580 |
+
if len(self.current_point_set) > 0:
|
| 581 |
+
if event.key() == QtCore.Qt.Key_Return:
|
| 582 |
+
self.add_set()
|
| 583 |
+
if self.NZ > 1:
|
| 584 |
+
if event.key() == QtCore.Qt.Key_Left:
|
| 585 |
+
self.currentZ = max(0, self.currentZ - 1)
|
| 586 |
+
self.scroll.setValue(self.currentZ)
|
| 587 |
+
updated = True
|
| 588 |
+
elif event.key() == QtCore.Qt.Key_Right:
|
| 589 |
+
self.currentZ = min(self.NZ - 1, self.currentZ + 1)
|
| 590 |
+
self.scroll.setValue(self.currentZ)
|
| 591 |
+
updated = True
|
| 592 |
+
else:
|
| 593 |
+
nviews = self.ViewDropDown.count() - 1
|
| 594 |
+
nviews += int(
|
| 595 |
+
self.ViewDropDown.model().item(self.ViewDropDown.count() -
|
| 596 |
+
1).isEnabled())
|
| 597 |
+
if event.key() == QtCore.Qt.Key_X:
|
| 598 |
+
self.MCheckBox.toggle()
|
| 599 |
+
if event.key() == QtCore.Qt.Key_Z:
|
| 600 |
+
self.OCheckBox.toggle()
|
| 601 |
+
if event.key() == QtCore.Qt.Key_Left or event.key(
|
| 602 |
+
) == QtCore.Qt.Key_A:
|
| 603 |
+
self.currentZ = max(0, self.currentZ - 1)
|
| 604 |
+
self.scroll.setValue(self.currentZ)
|
| 605 |
+
updated = True
|
| 606 |
+
elif event.key() == QtCore.Qt.Key_Right or event.key(
|
| 607 |
+
) == QtCore.Qt.Key_D:
|
| 608 |
+
self.currentZ = min(self.NZ - 1, self.currentZ + 1)
|
| 609 |
+
self.scroll.setValue(self.currentZ)
|
| 610 |
+
updated = True
|
| 611 |
+
elif event.key() == QtCore.Qt.Key_PageDown:
|
| 612 |
+
self.view = (self.view + 1) % (nviews)
|
| 613 |
+
self.ViewDropDown.setCurrentIndex(self.view)
|
| 614 |
+
elif event.key() == QtCore.Qt.Key_PageUp:
|
| 615 |
+
self.view = (self.view - 1) % (nviews)
|
| 616 |
+
self.ViewDropDown.setCurrentIndex(self.view)
|
| 617 |
+
|
| 618 |
+
# can change background or stroke size if cell not finished
|
| 619 |
+
if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W:
|
| 620 |
+
self.color = (self.color - 1) % (6)
|
| 621 |
+
self.RGBDropDown.setCurrentIndex(self.color)
|
| 622 |
+
elif event.key() == QtCore.Qt.Key_Down or event.key(
|
| 623 |
+
) == QtCore.Qt.Key_S:
|
| 624 |
+
self.color = (self.color + 1) % (6)
|
| 625 |
+
self.RGBDropDown.setCurrentIndex(self.color)
|
| 626 |
+
elif event.key() == QtCore.Qt.Key_R:
|
| 627 |
+
if self.color != 1:
|
| 628 |
+
self.color = 1
|
| 629 |
+
else:
|
| 630 |
+
self.color = 0
|
| 631 |
+
self.RGBDropDown.setCurrentIndex(self.color)
|
| 632 |
+
elif event.key() == QtCore.Qt.Key_G:
|
| 633 |
+
if self.color != 2:
|
| 634 |
+
self.color = 2
|
| 635 |
+
else:
|
| 636 |
+
self.color = 0
|
| 637 |
+
self.RGBDropDown.setCurrentIndex(self.color)
|
| 638 |
+
elif event.key() == QtCore.Qt.Key_B:
|
| 639 |
+
if self.color != 3:
|
| 640 |
+
self.color = 3
|
| 641 |
+
else:
|
| 642 |
+
self.color = 0
|
| 643 |
+
self.RGBDropDown.setCurrentIndex(self.color)
|
| 644 |
+
elif (event.key() == QtCore.Qt.Key_Comma or
|
| 645 |
+
event.key() == QtCore.Qt.Key_Period):
|
| 646 |
+
count = self.BrushChoose.count()
|
| 647 |
+
gci = self.BrushChoose.currentIndex()
|
| 648 |
+
if event.key() == QtCore.Qt.Key_Comma:
|
| 649 |
+
gci = max(0, gci - 1)
|
| 650 |
+
else:
|
| 651 |
+
gci = min(count - 1, gci + 1)
|
| 652 |
+
self.BrushChoose.setCurrentIndex(gci)
|
| 653 |
+
self.brush_choose()
|
| 654 |
+
if not updated:
|
| 655 |
+
self.update_plot()
|
| 656 |
+
if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal:
|
| 657 |
+
self.p0.keyPressEvent(event)
|
| 658 |
+
|
| 659 |
+
def update_ztext(self):
|
| 660 |
+
zpos = self.currentZ
|
| 661 |
+
try:
|
| 662 |
+
zpos = int(self.zpos.text())
|
| 663 |
+
except:
|
| 664 |
+
print("ERROR: zposition is not a number")
|
| 665 |
+
self.currentZ = max(0, min(self.NZ - 1, zpos))
|
| 666 |
+
self.zpos.setText(str(self.currentZ))
|
| 667 |
+
self.scroll.setValue(self.currentZ)
|
models/seg_post_model/cellpose/gui/guihelpwindowtext.html
ADDED
@@ -0,0 +1,143 @@
+<qt>
+<p class="has-line-data" data-line-start="5" data-line-end="6">
+    <b>Main GUI mouse controls:</b>
+</p>
+<ul>
+    <li class="has-line-data" data-line-start="7" data-line-end="8">Pan = left-click + drag</li>
+    <li class="has-line-data" data-line-start="8" data-line-end="9">Zoom = scroll wheel (or +/= and - buttons)</li>
+    <li class="has-line-data" data-line-start="9" data-line-end="10">Full view = double left-click</li>
+    <li class="has-line-data" data-line-start="10" data-line-end="11">Select mask = left-click on mask</li>
+    <li class="has-line-data" data-line-start="11" data-line-end="12">Delete mask = Ctrl (or COMMAND on Mac) +
+        left-click
+    </li>
+    <li class="has-line-data" data-line-start="11" data-line-end="12">Merge masks = Alt + left-click (will merge
+        last two)
+    </li>
+    <li class="has-line-data" data-line-start="12" data-line-end="13">Start draw mask = right-click</li>
+    <li class="has-line-data" data-line-start="13" data-line-end="15">End draw mask = right-click, or return to
+        circle at beginning
+    </li>
+</ul>
+<p class="has-line-data" data-line-start="15" data-line-end="16">Overlaps in masks are NOT allowed. If you
+    draw a mask on top of another mask, it is cropped so that it doesn’t overlap with the old mask. Masks in 2D
+    should be single strokes (single stroke is checked). If you want to draw masks in 3D (experimental), then
+    you can turn this option off and draw a stroke on each plane with the cell and then press ENTER. 3D
+    labelling will fill in the planes that you have not labelled, so that you do not have to label as densely.
+</p>
+<p class="has-line-data" data-line-start="17" data-line-end="18"> <b>!NOTE!:</b> The GUI automatically saves after
+    you draw a mask in 2D but NOT after 3D mask drawing and NOT after segmentation. Save in the file menu or
+    with Ctrl+S. The output file is in the same folder as the loaded image with <code>_seg.npy</code> appended.
+</p>
+
+<p class="has-line-data" data-line-start="19" data-line-end="20"> <b>Bulk Mask Deletion</b>
+    Clicking the 'delete multiple' button will allow you to select and delete multiple masks at once.
+    Masks can be deselected by clicking on them again. Once you have selected all the masks you want to delete,
+    click the 'done' button to delete them.
+    <br>
+    <br>
+    Alternatively, you can delete a region of masks by clicking the 'delete multiple' button to create a
+    rectangular region, and then moving and/or resizing the region to select the masks you want to delete.
+    Once you have selected the masks you want to delete, click the 'done' button to delete them.
+    <br>
+    <br>
+    At any point in the process, you can click the 'cancel' button to cancel the bulk deletion.
+</p>
+<hr>
+<table class="table table-striped table-bordered">
+    <br>
+    <br>
+    FYI there are tooltips throughout the GUI (hover over text to see)
+    <br>
+    <thead>
+        <tr>
+            <th>Keyboard shortcuts</th>
+            <th>Description</th>
+        </tr>
+    </thead>
+    <tbody>
+        <tr>
+            <td>=/+ button // - button</td>
+            <td>zoom in // zoom out</td>
+        </tr>
+        <tr>
+            <td>CTRL+Z</td>
+            <td>undo previously drawn mask/stroke</td>
+        </tr>
| 66 |
+
<tr>
|
| 67 |
+
<td>CTRL+Y</td>
|
| 68 |
+
<td>undo remove mask</td>
|
| 69 |
+
</tr>
|
| 70 |
+
<tr>
|
| 71 |
+
<td>CTRL+0</td>
|
| 72 |
+
<td>clear all masks</td>
|
| 73 |
+
</tr>
|
| 74 |
+
<tr>
|
| 75 |
+
<td>CTRL+L</td>
|
| 76 |
+
<td>load image (can alternatively drag and drop image)</td>
|
| 77 |
+
</tr>
|
| 78 |
+
<tr>
|
| 79 |
+
<td>CTRL+S</td>
|
| 80 |
+
<td>SAVE MASKS IN IMAGE to <code>_seg.npy</code> file</td>
|
| 81 |
+
</tr>
|
| 82 |
+
<tr>
|
| 83 |
+
<td>CTRL+T</td>
|
| 84 |
+
<td>train model using _seg.npy files in folder
|
| 85 |
+
</tr>
|
| 86 |
+
<tr>
|
| 87 |
+
<td>CTRL+P</td>
|
| 88 |
+
<td>load <code>_seg.npy</code> file (note: it will load automatically with image if it exists)</td>
|
| 89 |
+
</tr>
|
| 90 |
+
<tr>
|
| 91 |
+
<td>CTRL+M</td>
|
| 92 |
+
<td>load masks file (must be same size as image with 0 for NO mask, and 1,2,3… for masks)</td>
|
| 93 |
+
</tr>
|
| 94 |
+
<tr>
|
| 95 |
+
<td>CTRL+N</td>
|
| 96 |
+
<td>save masks as PNG</td>
|
| 97 |
+
</tr>
|
| 98 |
+
<tr>
|
| 99 |
+
<td>CTRL+R</td>
|
| 100 |
+
<td>save ROIs to native ImageJ ROI format</td>
|
| 101 |
+
</tr>
|
| 102 |
+
<tr>
|
| 103 |
+
<td>CTRL+F</td>
|
| 104 |
+
<td>save flows to image file</td>
|
| 105 |
+
</tr>
|
| 106 |
+
<tr>
|
| 107 |
+
<td>A/D or LEFT/RIGHT</td>
|
| 108 |
+
<td>cycle through images in current directory</td>
|
| 109 |
+
</tr>
|
| 110 |
+
<tr>
|
| 111 |
+
<td>W/S or UP/DOWN</td>
|
| 112 |
+
<td>change color (RGB/gray/red/green/blue)</td>
|
| 113 |
+
</tr>
|
| 114 |
+
<tr>
|
| 115 |
+
<td>R / G / B</td>
|
| 116 |
+
<td>toggle between RGB and Red or Green or Blue</td>
|
| 117 |
+
</tr>
|
| 118 |
+
<tr>
|
| 119 |
+
<td>PAGE-UP / PAGE-DOWN</td>
|
| 120 |
+
<td>change to flows and cell prob views (if segmentation computed)</td>
|
| 121 |
+
</tr>
|
| 122 |
+
<tr>
|
| 123 |
+
<td>X</td>
|
| 124 |
+
<td>turn masks ON or OFF</td>
|
| 125 |
+
</tr>
|
| 126 |
+
<tr>
|
| 127 |
+
<td>Z</td>
|
| 128 |
+
<td>toggle outlines ON or OFF</td>
|
| 129 |
+
</tr>
|
| 130 |
+
<tr>
|
| 131 |
+
<td>, / .</td>
|
| 132 |
+
<td>increase / decrease brush size for drawing masks</td>
|
| 133 |
+
</tr>
|
| 134 |
+
</tbody>
|
| 135 |
+
</table>
|
| 136 |
+
<p class="has-line-data" data-line-start="36" data-line-end="37"><strong>Segmentation options
|
| 137 |
+
(2D only) </strong></p>
|
| 138 |
+
<p class="has-line-data" data-line-start="38" data-line-end="39">use GPU: if you have specially
|
| 139 |
+
installed the cuda version of torch, then you can activate this. Due to the size of the
|
| 140 |
+
transformer network, it will greatly speed up the processing time.</p>
|
| 141 |
+
<p class="has-line-data" data-line-start="40" data-line-end="41">There are no channel options
|
| 142 |
+
in v4.0.1+ since all 3 channels are used for segmentation. </p>
|
| 143 |
+
</qt>
|
models/seg_post_model/cellpose/gui/guiparts.py
ADDED
|
@@ -0,0 +1,793 @@
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""
from qtpy import QtGui, QtCore
from qtpy.QtGui import QPixmap, QDoubleValidator
from qtpy.QtWidgets import QWidget, QDialog, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox, QVBoxLayout
import pyqtgraph as pg
import numpy as np
import pathlib, os


def stylesheet():
    return """
        QToolTip {
            background-color: black;
            color: white;
            border: black solid 1px
        }
        QComboBox {color: white;
            background-color: rgb(40,40,40);}
        QComboBox::item:enabled { color: white;
            background-color: rgb(40,40,40);
            selection-color: white;
            selection-background-color: rgb(50,100,50);}
        QComboBox::item:!enabled {
            background-color: rgb(40,40,40);
            color: rgb(100,100,100);
        }
        QScrollArea > QWidget > QWidget
        {
            background: transparent;
            border: none;
            margin: 0px 0px 0px 0px;
        }

        QGroupBox
        { border: 1px solid white; color: rgb(255,255,255);
            border-radius: 6px;
            margin-top: 8px;
            padding: 0px 0px;}

        QPushButton:pressed {Text-align: center;
            background-color: rgb(150,50,150);
            border-color: white;
            color: white;}
        QPushButton:!pressed {Text-align: center;
            background-color: rgb(50,50,50);
            border-color: white;
            color: white;}
        QPushButton:disabled {Text-align: center;
            background-color: rgb(30,30,30);
            border-color: white;
            color: rgb(80,80,80);}
        """


class DarkPalette(QtGui.QPalette):
    """Class that inherits from pyqtgraph.QtGui.QPalette and renders dark colours for the application.
    (from pykilosort/kilosort4)
    """

    def __init__(self):
        QtGui.QPalette.__init__(self)
        self.setup()

    def setup(self):
        self.setColor(QtGui.QPalette.Window, QtGui.QColor(40, 40, 40))
        self.setColor(QtGui.QPalette.WindowText, QtGui.QColor(255, 255, 255))
        self.setColor(QtGui.QPalette.Base, QtGui.QColor(34, 27, 24))
        self.setColor(QtGui.QPalette.AlternateBase, QtGui.QColor(53, 50, 47))
        self.setColor(QtGui.QPalette.ToolTipBase, QtGui.QColor(255, 255, 255))
        self.setColor(QtGui.QPalette.ToolTipText, QtGui.QColor(255, 255, 255))
        self.setColor(QtGui.QPalette.Text, QtGui.QColor(255, 255, 255))
        self.setColor(QtGui.QPalette.Button, QtGui.QColor(53, 50, 47))
        self.setColor(QtGui.QPalette.ButtonText, QtGui.QColor(255, 255, 255))
        self.setColor(QtGui.QPalette.BrightText, QtGui.QColor(255, 0, 0))
        self.setColor(QtGui.QPalette.Link, QtGui.QColor(42, 130, 218))
        self.setColor(QtGui.QPalette.Highlight, QtGui.QColor(42, 130, 218))
        self.setColor(QtGui.QPalette.HighlightedText, QtGui.QColor(0, 0, 0))
        self.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.Text,
                      QtGui.QColor(128, 128, 128))
        self.setColor(
            QtGui.QPalette.Disabled,
            QtGui.QPalette.ButtonText,
            QtGui.QColor(128, 128, 128),
        )
        self.setColor(
            QtGui.QPalette.Disabled,
            QtGui.QPalette.WindowText,
            QtGui.QColor(128, 128, 128),
        )

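
def _demo_apply_dark_theme(app):
    """Illustrative sketch (assumed usage, not this repo's entry point): a dark
    theme like the one above is typically installed once on the QApplication,
    the palette for base colors and the stylesheet for per-widget rules."""
    app.setPalette(DarkPalette())      # base window/button/text colors
    app.setStyleSheet(stylesheet())    # finer-grained rules (tooltips, combo boxes, ...)
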
# def create_channel_choose():
#     # choose channel
#     ChannelChoose = [QComboBox(), QComboBox()]
#     ChannelLabels = []
#     ChannelChoose[0].addItems(["gray", "red", "green", "blue"])
#     ChannelChoose[1].addItems(["none", "red", "green", "blue"])
#     cstr = ["chan to segment:", "chan2 (optional): "]
#     for i in range(2):
#         ChannelLabels.append(QLabel(cstr[i]))
#         if i == 0:
#             ChannelLabels[i].setToolTip(
#                 "this is the channel in which the cytoplasm or nuclei exist \
#                 that you want to segment")
#             ChannelChoose[i].setToolTip(
#                 "this is the channel in which the cytoplasm or nuclei exist \
#                 that you want to segment")
#         else:
#             ChannelLabels[i].setToolTip(
#                 "if <em>cytoplasm</em> model is chosen, and you also have a \
#                 nuclear channel, then choose the nuclear channel for this option")
#             ChannelChoose[i].setToolTip(
#                 "if <em>cytoplasm</em> model is chosen, and you also have a \
#                 nuclear channel, then choose the nuclear channel for this option")
#
#     return ChannelChoose, ChannelLabels


class ModelButton(QPushButton):

    def __init__(self, parent, model_name, text):
        super().__init__()
        self.setEnabled(False)
        self.setText(text)
        self.setFont(parent.boldfont)
        self.clicked.connect(lambda: self.press(parent))
        self.model_name = "cpsam"

    def press(self, parent):
        parent.compute_segmentation(model_name="cpsam")


class FilterButton(QPushButton):

    def __init__(self, parent, text):
        super().__init__()
        self.setEnabled(False)
        self.model_type = text
        self.setText(text)
        self.setFont(parent.medfont)
        self.clicked.connect(lambda: self.press(parent))

    def press(self, parent):
        if self.model_type == "filter":
            parent.restore = "filter"
            normalize_params = parent.get_normalize_params()
            if (normalize_params["sharpen_radius"] == 0 and
                    normalize_params["smooth_radius"] == 0 and
                    normalize_params["tile_norm_blocksize"] == 0):
                print(
                    "GUI_ERROR: no filtering settings on (use custom filter settings)")
                parent.restore = None
                return
            parent.restore = self.model_type
            parent.compute_saturation()
        # elif self.model_type != "none":
        #     parent.compute_denoise_model(model_type=self.model_type)
        else:
            parent.clear_restore()
        # parent.set_restore_button()

class ObservableVariable(QtCore.QObject):
    valueChanged = QtCore.Signal(object)

    def __init__(self, initial=None):
        super().__init__()
        self._value = initial

    def set(self, new_value):
        """ Use this method to emit the value change and update the ROI count """
        if new_value != self._value:
            self._value = new_value
            self.valueChanged.emit(new_value)

    def get(self):
        return self._value

    def __call__(self):
        return self._value

    def reset(self):
        self.set(0)

    def __iadd__(self, amount):
        if not isinstance(amount, (int, float)):
            raise TypeError("Value must be numeric.")
        self.set(self._value + amount)
        return self

    def __radd__(self, other):
        return other + self._value

    def __add__(self, other):
        return other + self._value

    def __isub__(self, amount):
        if not isinstance(amount, (int, float)):
            raise TypeError("Value must be numeric.")
        self.set(self._value - amount)
        return self

    def __str__(self):
        return str(self._value)

    def __lt__(self, x):
        return self._value < x

    def __gt__(self, x):
        return self._value > x

    def __eq__(self, x):
        return self._value == x

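
def _demo_observable_variable():
    """Illustrative sketch (not called by the GUI): ObservableVariable wraps a
    plain value in a Qt signal so widgets can react when it changes."""
    roi_count = ObservableVariable(0)
    roi_count.valueChanged.connect(lambda n: print(f"ROIs: {n}"))
    roi_count += 1      # emits valueChanged(1)
    roi_count.set(5)    # emits valueChanged(5)
    roi_count.set(5)    # no emission: value unchanged
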
class NormalizationSettings(QWidget):
    # TODO
    pass


class SegmentationSettings(QWidget):
    """ Container for gui settings. Validation is done automatically so any attributes can
    be accessed without concern.
    """

    def __init__(self, font):
        super().__init__()

        # Put everything in a grid layout:
        grid_layout = QGridLayout()
        widget_container = QWidget()
        widget_container.setLayout(grid_layout)
        row = 0

        ########################### Diameter ###########################
        # TODO: Validate inputs
        diam_qlabel = QLabel("diameter:")
        diam_qlabel.setToolTip("diameter of cells in pixels. If not 30, image will be resized to this")
        diam_qlabel.setFont(font)
        grid_layout.addWidget(diam_qlabel, row, 0, 1, 2)
        self.diameter_box = QLineEdit()
        self.diameter_box.setToolTip("diameter of cells in pixels. If not blank, image will be resized relative to 30 pixel cell diameters")
        self.diameter_box.setFont(font)
        self.diameter_box.setFixedWidth(40)
        self.diameter_box.setText(' ')
        grid_layout.addWidget(self.diameter_box, row, 2, 1, 2)

        row += 1

        ########################### Flow threshold ###########################
        # TODO: Validate inputs
        flow_threshold_qlabel = QLabel("flow\nthreshold:")
        flow_threshold_qlabel.setToolTip("threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run")
        flow_threshold_qlabel.setFont(font)
        grid_layout.addWidget(flow_threshold_qlabel, row, 0, 1, 2)
        self.flow_threshold_box = QLineEdit()
        self.flow_threshold_box.setText("0.4")
        self.flow_threshold_box.setFixedWidth(40)
        self.flow_threshold_box.setFont(font)
        grid_layout.addWidget(self.flow_threshold_box, row, 2, 1, 2)
        self.flow_threshold_box.setToolTip("threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run")

        ########################### Cellprob threshold ###########################
        # TODO: Validate inputs
        cellprob_qlabel = QLabel("cellprob\nthreshold:")
        cellprob_qlabel.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run")
        cellprob_qlabel.setFont(font)
        grid_layout.addWidget(cellprob_qlabel, row, 4, 1, 2)
        self.cellprob_threshold_box = QLineEdit()
        self.cellprob_threshold_box.setText("0.0")
        self.cellprob_threshold_box.setFixedWidth(40)
        self.cellprob_threshold_box.setFont(font)
        self.cellprob_threshold_box.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run")
        grid_layout.addWidget(self.cellprob_threshold_box, row, 6, 1, 2)

        row += 1

        ########################### Norm percentiles ###########################
        norm_percentiles_qlabel = QLabel("norm percentiles:")
        norm_percentiles_qlabel.setToolTip("sets normalization percentiles for segmentation and denoising\n(pixels at lower percentile set to 0.0 and at upper set to 1.0 for network)")
        norm_percentiles_qlabel.setFont(font)
        grid_layout.addWidget(norm_percentiles_qlabel, row, 0, 1, 8)

        row += 1
        validator = QDoubleValidator(0.0, 100.0, 2)
        validator.setNotation(QDoubleValidator.StandardNotation)

        low_norm_qlabel = QLabel('lower:')
        low_norm_qlabel.setToolTip("pixels at this percentile set to 0 (default 1.0)")
        low_norm_qlabel.setFont(font)
        grid_layout.addWidget(low_norm_qlabel, row, 0, 1, 2)
        self.norm_percentile_low_box = QLineEdit()
        self.norm_percentile_low_box.setText("1.0")
        self.norm_percentile_low_box.setFont(font)
        self.norm_percentile_low_box.setFixedWidth(40)
        self.norm_percentile_low_box.setToolTip("pixels at this percentile set to 0 (default 1.0)")
        self.norm_percentile_low_box.setValidator(validator)
        self.norm_percentile_low_box.editingFinished.connect(self.validate_normalization_range)
        grid_layout.addWidget(self.norm_percentile_low_box, row, 2, 1, 1)

        high_norm_qlabel = QLabel('upper:')
        high_norm_qlabel.setToolTip("pixels at this percentile set to 1 (default 99.0)")
        high_norm_qlabel.setFont(font)
        grid_layout.addWidget(high_norm_qlabel, row, 4, 1, 2)
        self.norm_percentile_high_box = QLineEdit()
        self.norm_percentile_high_box.setText("99.0")
        self.norm_percentile_high_box.setFont(font)
        self.norm_percentile_high_box.setFixedWidth(40)
        self.norm_percentile_high_box.setToolTip("pixels at this percentile set to 1 (default 99.0)")
        self.norm_percentile_high_box.setValidator(validator)
        self.norm_percentile_high_box.editingFinished.connect(self.validate_normalization_range)
        grid_layout.addWidget(self.norm_percentile_high_box, row, 6, 1, 2)

        row += 1

        ########################### niter ###########################
        # TODO: change this to follow the same default logic as 'diameter' above
        # TODO: input validation
        niter_qlabel = QLabel("niter dynamics:")
        niter_qlabel.setFont(font)
        niter_qlabel.setToolTip("number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria")
        grid_layout.addWidget(niter_qlabel, row, 0, 1, 4)
        self.niter_box = QLineEdit()
        self.niter_box.setText("0")
        self.niter_box.setFixedWidth(40)
        self.niter_box.setFont(font)
        self.niter_box.setToolTip("number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria")
        grid_layout.addWidget(self.niter_box, row, 4, 1, 2)

        self.setLayout(grid_layout)

    def validate_normalization_range(self):
        low_text = self.norm_percentile_low_box.text()
        high_text = self.norm_percentile_high_box.text()

        if not low_text or low_text.isspace():
            self.norm_percentile_low_box.setText('1.0')
            low_text = '1.0'
        if not high_text or high_text.isspace():
            self.norm_percentile_high_box.setText('99.0')
            high_text = '99.0'

        low = float(low_text)
        high = float(high_text)

        if low >= high:
            # Invalid: show error and mark fields
            self.norm_percentile_low_box.setStyleSheet("border: 1px solid red;")
            self.norm_percentile_high_box.setStyleSheet("border: 1px solid red;")
        else:
            # Valid: clear style
            self.norm_percentile_low_box.setStyleSheet("")
            self.norm_percentile_high_box.setStyleSheet("")

    @property
    def low_percentile(self):
        """ Also validate the low input by returning 1.0 if text doesn't work """
        low_text = self.norm_percentile_low_box.text()
        if not low_text or low_text.isspace():
            self.norm_percentile_low_box.setText('1.0')
            low_text = '1.0'
        return float(self.norm_percentile_low_box.text())

    @property
    def high_percentile(self):
        """ Also validate the high input by returning 99.0 if text doesn't work """
        high_text = self.norm_percentile_high_box.text()
        if not high_text or high_text.isspace():
            self.norm_percentile_high_box.setText('99.0')
            high_text = '99.0'
        return float(self.norm_percentile_high_box.text())

    @property
    def diameter(self):
        """ Get the diameter from the diameter box, if box isn't a number return None """
        try:
            d = float(self.diameter_box.text())
        except ValueError:
            d = None
        return d

    @property
    def flow_threshold(self):
        return float(self.flow_threshold_box.text())

    @property
    def cellprob_threshold(self):
        return float(self.cellprob_threshold_box.text())

    @property
    def niter(self):
        num = int(self.niter_box.text())
        if num < 1:
            self.niter_box.setText('200')
            return 200
        else:
            return num

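
def _demo_percentile_normalize(img, low=1.0, high=99.0):
    """Illustrative sketch (not cellpose's internal routine): what the lower/
    upper percentile boxes above feed downstream, pixels at the `low`
    percentile map to 0.0 and pixels at the `high` percentile map to 1.0."""
    lo, hi = np.percentile(img, [low, high])
    return np.clip((img.astype(np.float32) - lo) / max(hi - lo, 1e-6), 0.0, 1.0)
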
class TrainWindow(QDialog):

    def __init__(self, parent, model_strings):
        super().__init__(parent)
        self.setGeometry(100, 100, 900, 550)
        self.setWindowTitle("train settings")
        self.win = QWidget(self)
        self.l0 = QGridLayout()
        self.win.setLayout(self.l0)

        yoff = 0
        qlabel = QLabel("train model w/ images + _seg.npy in current folder >>")
        qlabel.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))

        qlabel.setAlignment(QtCore.Qt.AlignVCenter)
        self.l0.addWidget(qlabel, yoff, 0, 1, 2)

        # choose initial model
        yoff += 1
        self.ModelChoose = QComboBox()
        self.ModelChoose.addItems(model_strings)
        self.ModelChoose.setFixedWidth(150)
        self.ModelChoose.setCurrentIndex(parent.training_params["model_index"])
        self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1)
        qlabel = QLabel("initial model: ")
        qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
        self.l0.addWidget(qlabel, yoff, 0, 1, 1)

        # choose parameters
        labels = ["learning_rate", "weight_decay", "n_epochs", "model_name"]
        self.edits = []
        yoff += 1
        for i, label in enumerate(labels):
            qlabel = QLabel(label)
            qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
            self.l0.addWidget(qlabel, i + yoff, 0, 1, 1)
            self.edits.append(QLineEdit())
            self.edits[-1].setText(str(parent.training_params[label]))
            self.edits[-1].setFixedWidth(200)
            self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1)

        yoff += len(labels)

        yoff += 1
        self.use_norm = QCheckBox("use restored/filtered image")
        self.use_norm.setChecked(True)

        yoff += 2
        qlabel = QLabel(
            "(to remove files, click cancel then remove \nfrom folder and reopen train window)"
        )
        self.l0.addWidget(qlabel, yoff, 0, 2, 4)

        # click button
        yoff += 3
        QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
        self.buttonBox = QDialogButtonBox(QBtn)
        self.buttonBox.accepted.connect(lambda: self.accept(parent))
        self.buttonBox.rejected.connect(self.reject)
        self.l0.addWidget(self.buttonBox, yoff, 0, 1, 4)

        # list files in folder
        qlabel = QLabel("filenames")
        qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
        self.l0.addWidget(qlabel, 0, 4, 1, 1)
        qlabel = QLabel("# of masks")
        qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
        self.l0.addWidget(qlabel, 0, 5, 1, 1)

        for i in range(10):
            if i > len(parent.train_files) - 1:
                break
            elif i == 9 and len(parent.train_files) > 10:
                label = "..."
                nmasks = "..."
            else:
                label = os.path.split(parent.train_files[i])[-1]
                nmasks = str(parent.train_labels[i].max())
            qlabel = QLabel(label)
            self.l0.addWidget(qlabel, i + 1, 4, 1, 1)
            qlabel = QLabel(nmasks)
            qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
            self.l0.addWidget(qlabel, i + 1, 5, 1, 1)

    def accept(self, parent):
        # set training params
        parent.training_params = {
            "model_index": self.ModelChoose.currentIndex(),
            "learning_rate": float(self.edits[0].text()),
            "weight_decay": float(self.edits[1].text()),
            "n_epochs": int(self.edits[2].text()),
            "model_name": self.edits[3].text(),
            #"use_norm": True if self.use_norm.isChecked() else False,
        }
        self.done(1)

class ExampleGUI(QDialog):

    def __init__(self, parent=None):
        super(ExampleGUI, self).__init__(parent)
        self.setGeometry(100, 100, 1300, 900)
        self.setWindowTitle("GUI layout")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)
        guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png")
        guip_path = str(guip_path.resolve())
        pixmap = QPixmap(guip_path)
        label = QLabel(self)
        label.setPixmap(pixmap)
        layout.addWidget(label, 0, 0, 1, 1)


class HelpWindow(QDialog):

    def __init__(self, parent=None):
        super(HelpWindow, self).__init__(parent)
        self.setGeometry(100, 50, 700, 1000)
        self.setWindowTitle("cellpose help")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)

        text_file = pathlib.Path(__file__).parent.joinpath("guihelpwindowtext.html")
        with open(str(text_file.resolve()), "r") as f:
            text = f.read()

        label = QLabel(text)
        label.setFont(QtGui.QFont("Arial", 8))
        label.setWordWrap(True)
        layout.addWidget(label, 0, 0, 1, 1)
        self.show()


class TrainHelpWindow(QDialog):

    def __init__(self, parent=None):
        super(TrainHelpWindow, self).__init__(parent)
        self.setGeometry(100, 50, 700, 300)
        self.setWindowTitle("training instructions")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)

        text_file = pathlib.Path(__file__).parent.joinpath(
            "guitrainhelpwindowtext.html")
        with open(str(text_file.resolve()), "r") as f:
            text = f.read()

        label = QLabel(text)
        label.setFont(QtGui.QFont("Arial", 8))
        label.setWordWrap(True)
        layout.addWidget(label, 0, 0, 1, 1)
        self.show()


class ViewBoxNoRightDrag(pg.ViewBox):

    def __init__(self, parent=None, border=None, lockAspect=False, enableMouse=True,
                 invertY=False, enableMenu=True, name=None, invertX=False):
        pg.ViewBox.__init__(self, None, border, lockAspect, enableMouse, invertY,
                            enableMenu, name, invertX)
        self.parent = parent
        self.axHistoryPointer = -1

    def keyPressEvent(self, ev):
        """
        This routine captures key presses in the current view box.
        The following events are implemented:
        +/= : zooms in
        -   : zooms out
        """
        ev.accept()
        if ev.text() == "-":
            self.scaleBy([1.1, 1.1])
        elif ev.text() in ["+", "="]:
            self.scaleBy([0.9, 0.9])
        else:
            ev.ignore()


class ImageDraw(pg.ImageItem):
    """
    **Bases:** :class:`GraphicsObject <pyqtgraph.GraphicsObject>`
    GraphicsObject displaying an image. Optimized for rapid update (ie video display).
    This item displays either a 2D numpy array (height, width) or
    a 3D array (height, width, RGBa). This array is optionally scaled (see
    :func:`setLevels <pyqtgraph.ImageItem.setLevels>`) and/or colored
    with a lookup table (see :func:`setLookupTable <pyqtgraph.ImageItem.setLookupTable>`)
    before being displayed.
    ImageItem is frequently used in conjunction with
    :class:`HistogramLUTItem <pyqtgraph.HistogramLUTItem>` or
    :class:`HistogramLUTWidget <pyqtgraph.HistogramLUTWidget>` to provide a GUI
    for controlling the levels and lookup table used to display the image.
    """

    sigImageChanged = QtCore.Signal()

    def __init__(self, image=None, viewbox=None, parent=None, **kargs):
        super(ImageDraw, self).__init__()
        self.levels = np.array([0, 255])
        self.lut = None
        self.autoDownsample = False
        self.axisOrder = "row-major"
        self.removable = False

        self.parent = parent
        self.setDrawKernel(kernel_size=self.parent.brush_size)
        self.parent.current_stroke = []
        self.parent.in_stroke = False

    def mouseClickEvent(self, ev):
        if (self.parent.masksOn or
                self.parent.outlinesOn) and not self.parent.removing_region:
            is_right_click = ev.button() == QtCore.Qt.RightButton
            if self.parent.loaded \
                    and (is_right_click or ev.modifiers() & QtCore.Qt.ShiftModifier and not ev.double()) \
                    and not self.parent.deleting_multiple:
                if not self.parent.in_stroke:
                    ev.accept()
                    self.create_start(ev.pos())
                    self.parent.stroke_appended = False
                    self.parent.in_stroke = True
                    self.drawAt(ev.pos(), ev)
                else:
                    ev.accept()
                    self.end_stroke()
                    self.parent.in_stroke = False
            elif not self.parent.in_stroke:
                y, x = int(ev.pos().y()), int(ev.pos().x())
                if y >= 0 and y < self.parent.Ly and x >= 0 and x < self.parent.Lx:
                    if ev.button() == QtCore.Qt.LeftButton and not ev.double():
                        idx = self.parent.cellpix[self.parent.currentZ][y, x]
                        if idx > 0:
                            if ev.modifiers() & QtCore.Qt.ControlModifier:
                                # delete mask selected
                                self.parent.remove_cell(idx)
                            elif ev.modifiers() & QtCore.Qt.AltModifier:
                                self.parent.merge_cells(idx)
                            elif self.parent.masksOn and not self.parent.deleting_multiple:
                                self.parent.unselect_cell()
                                self.parent.select_cell(idx)
                            elif self.parent.deleting_multiple:
                                if idx in self.parent.removing_cells_list:
                                    self.parent.unselect_cell_multi(idx)
                                    self.parent.removing_cells_list.remove(idx)
                                else:
                                    self.parent.select_cell_multi(idx)
                                    self.parent.removing_cells_list.append(idx)
                        elif self.parent.masksOn and not self.parent.deleting_multiple:
                            self.parent.unselect_cell()

    def mouseDragEvent(self, ev):
        ev.ignore()
        return

    def hoverEvent(self, ev):
        if self.parent.in_stroke:
            # continue stroke if not at start
            self.drawAt(ev.pos())
            if self.is_at_start(ev.pos()):
                self.end_stroke()
        else:
            ev.acceptClicks(QtCore.Qt.RightButton)

    def create_start(self, pos):
        self.scatter = pg.ScatterPlotItem([pos.x()], [pos.y()], pxMode=False,
                                          pen=pg.mkPen(color=(255, 0, 0),
                                                       width=self.parent.brush_size),
                                          size=max(3 * 2,
                                                   self.parent.brush_size * 1.8 * 2),
                                          brush=None)
        self.parent.p0.addItem(self.scatter)

    def is_at_start(self, pos):
        thresh_out = max(6, self.parent.brush_size * 3)
        thresh_in = max(3, self.parent.brush_size * 1.8)
        # first check if you ever left the start
        if len(self.parent.current_stroke) > 3:
            stroke = np.array(self.parent.current_stroke)
            dist = (((stroke[1:, 1:] -
                      stroke[:1, 1:][np.newaxis, :, :])**2).sum(axis=-1))**0.5
            dist = dist.flatten()
            has_left = (dist > thresh_out).nonzero()[0]
            if len(has_left) > 0:
                first_left = np.sort(has_left)[0]
                has_returned = (dist[max(4, first_left + 1):] < thresh_in).sum()
                if has_returned > 0:
                    return True
                else:
                    return False
            else:
                return False
        return False

    def end_stroke(self):
        self.parent.p0.removeItem(self.scatter)
        if not self.parent.stroke_appended:
            self.parent.strokes.append(self.parent.current_stroke)
            self.parent.stroke_appended = True
            self.parent.current_stroke = np.array(self.parent.current_stroke)
            ioutline = self.parent.current_stroke[:, 3] == 1
            self.parent.current_point_set.append(
                list(self.parent.current_stroke[ioutline]))
            self.parent.current_stroke = []
            if self.parent.autosave:
                self.parent.add_set()
        if len(self.parent.current_point_set) and len(
                self.parent.current_point_set[0]) > 0 and self.parent.autosave:
            self.parent.add_set()
        self.parent.in_stroke = False

    def tabletEvent(self, ev):
        pass

    def drawAt(self, pos, ev=None):
        mask = self.strokemask
        stroke = self.parent.current_stroke
        pos = [int(pos.y()), int(pos.x())]
        dk = self.drawKernel
        kc = self.drawKernelCenter
        sx = [0, dk.shape[0]]
        sy = [0, dk.shape[1]]
        tx = [pos[0] - kc[0], pos[0] - kc[0] + dk.shape[0]]
        ty = [pos[1] - kc[1], pos[1] - kc[1] + dk.shape[1]]
        kcent = kc.copy()
        if tx[0] <= 0:
            sx[0] = 0
            sx[1] = kc[0] + 1
            tx = sx
            kcent[0] = 0
        if ty[0] <= 0:
            sy[0] = 0
            sy[1] = kc[1] + 1
            ty = sy
            kcent[1] = 0
        if tx[1] >= self.parent.Ly - 1:
            sx[0] = dk.shape[0] - kc[0] - 1
            sx[1] = dk.shape[0]
            tx[0] = self.parent.Ly - kc[0] - 1
            tx[1] = self.parent.Ly
            kcent[0] = tx[1] - tx[0] - 1
        if ty[1] >= self.parent.Lx - 1:
            sy[0] = dk.shape[1] - kc[1] - 1
            sy[1] = dk.shape[1]
            ty[0] = self.parent.Lx - kc[1] - 1
            ty[1] = self.parent.Lx
            kcent[1] = ty[1] - ty[0] - 1

        ts = (slice(tx[0], tx[1]), slice(ty[0], ty[1]))
        ss = (slice(sx[0], sx[1]), slice(sy[0], sy[1]))
        self.image[ts] = mask[ss]

        for ky, y in enumerate(np.arange(ty[0], ty[1], 1, int)):
            for kx, x in enumerate(np.arange(tx[0], tx[1], 1, int)):
                iscent = np.logical_and(kx == kcent[0], ky == kcent[1])
                stroke.append([self.parent.currentZ, x, y, iscent])
        self.updateImage()

    def setDrawKernel(self, kernel_size=3):
        bs = kernel_size
        kernel = np.ones((bs, bs), np.uint8)
        self.drawKernel = kernel
        self.drawKernelCenter = [
            int(np.floor(kernel.shape[0] / 2)),
            int(np.floor(kernel.shape[1] / 2))
        ]
        onmask = 255 * kernel[:, :, np.newaxis]
        offmask = np.zeros((bs, bs, 1))
        opamask = 100 * kernel[:, :, np.newaxis]
        self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1)
        self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1)
models/seg_post_model/cellpose/gui/guitrainhelpwindowtext.html
ADDED
|
@@ -0,0 +1,25 @@
<qt>
Check out this <a href="https://youtu.be/3Y1VKcxjNy4">video</a> to learn the process.
<ol>
    <li>Drag and drop an image from a folder of images with a similar style (like similar cell types).</li>
    <li>Run the built-in models on one of the images using the "model zoo" and find the one that works best for your
        data. Make sure that if you have a nuclear channel you have selected it for CHAN2.
    </li>
    <li>Fix the labelling by drawing new ROIs (right-click) and deleting incorrect ones (CTRL+click). The GUI
        autosaves any manual changes (but does not autosave after running the model; for that, click CTRL+S). The
        segmentation is saved in a "_seg.npy" file.
    </li>
    <li>Go to the "Models" menu in the File bar at the top and click "Train new model..." or use shortcut CTRL+T.
    </li>
    <li>Choose the pretrained model to start the training from (the model you used in #2), and type in the model
        name that you want to use. The other parameters should work well in general for most data types. Then click
        OK.
    </li>
    <li>The model will train (much faster if you have a GPU) and then auto-run on the next image in the folder.
        Next you can repeat #3-#5 as many times as necessary.
    </li>
    <li>The trained model is available to use in the future in the GUI in the "custom model" section and is saved
        in your image folder.
    </li>
</ol>
</qt>
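The training workflow above can also be scripted outside the GUI. A minimal sketch, under the assumption that cellpose's Python training API (`cellpose.train.train_seg`, present in recent releases) takes roughly these arguments and returns the saved model path — verify the exact signature against your installed version:

from cellpose import models, train

model = models.CellposeModel(gpu=True)   # loads the pretrained weights
model_path = train.train_seg(
    model.net,
    train_data=train_data, train_labels=train_labels,   # e.g. as gathered by _get_train_set in io.py below
    n_epochs=100, learning_rate=1e-5, weight_decay=0.1,
    model_name="my_new_model",
)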
models/seg_post_model/cellpose/gui/io.py
ADDED
|
@@ -0,0 +1,634 @@
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""
import os, gc
import numpy as np
import cv2
import fastremap

from ..io import imread, imread_2D, imread_3D, imsave, outlines_to_text, add_model, remove_model, save_rois
from ..models import normalize_default, MODEL_DIR, MODEL_LIST_PATH, get_user_models
from ..utils import masks_to_outlines, outlines_list

try:
    import qtpy
    from qtpy.QtWidgets import QFileDialog
    GUI = True
except ImportError:
    GUI = False

try:
    import matplotlib.pyplot as plt
    MATPLOTLIB = True
except ImportError:
    MATPLOTLIB = False


def _init_model_list(parent):
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    parent.model_list_path = MODEL_LIST_PATH
    parent.model_strings = get_user_models()


def _add_model(parent, filename=None, load_model=True):
    if filename is None:
        name = QFileDialog.getOpenFileName(parent, "Add model to GUI")
        filename = name[0]
    add_model(filename)
    fname = os.path.split(filename)[-1]
    parent.ModelChooseC.addItems([fname])
    parent.model_strings.append(fname)

    for ind, model_string in enumerate(parent.model_strings[:-1]):
        if model_string == fname:
            _remove_model(parent, ind=ind + 1, verbose=False)

    parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
    if load_model:
        parent.model_choose(custom=True)


def _remove_model(parent, ind=None, verbose=True):
    if ind is None:
        ind = parent.ModelChooseC.currentIndex()
    if ind > 0:
        ind -= 1
        parent.ModelChooseC.removeItem(ind + 1)
        del parent.model_strings[ind]
        # remove model from txt path
        modelstr = parent.ModelChooseC.currentText()
        remove_model(modelstr)
        if len(parent.model_strings) > 0:
            parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
        else:
            parent.ModelChooseC.setCurrentIndex(0)
    else:
        print("ERROR: no model selected to delete")


def _get_train_set(image_names):
    """ get training data and labels for images in current folder image_names """
    train_data, train_labels, train_files = [], [], []
    restore = None
    normalize_params = normalize_default
    for image_name_full in image_names:
        image_name = os.path.splitext(image_name_full)[0]
        label_name = None
        if os.path.exists(image_name + "_seg.npy"):
            dat = np.load(image_name + "_seg.npy", allow_pickle=True).item()
            masks = dat["masks"].squeeze()
            if masks.ndim == 2:
                fastremap.renumber(masks, in_place=True)
                label_name = image_name + "_seg.npy"
            else:
                print(f"GUI_INFO: _seg.npy found for {image_name} but masks.ndim!=2")
            if "img_restore" in dat:
                data = dat["img_restore"].squeeze()
                restore = dat["restore"]
            else:
                data = imread(image_name_full)
            normalize_params = dat[
                "normalize_params"] if "normalize_params" in dat else normalize_default
        if label_name is not None:
            train_files.append(image_name_full)
            train_data.append(data)
            train_labels.append(masks)
    if restore:
        print(f"GUI_INFO: using {restore} images (dat['img_restore'])")
    return train_data, train_labels, train_files, restore, normalize_params

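
def _demo_inspect_seg_file(path):
    """Illustrative sketch (not called by the GUI): the keys that _get_train_set
    above reads from a GUI-produced *_seg.npy file."""
    dat = np.load(path, allow_pickle=True).item()
    masks = dat["masks"].squeeze()
    print(masks.shape, int(masks.max()), "ROIs")
    if "img_restore" in dat:                 # optional restored/filtered image
        print("restore mode:", dat["restore"])
    print("normalize_params" in dat)         # optional normalization settings
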
def _load_image(parent, filename=None, load_seg=True, load_3D=False):
    """ load image with filename; if None, open QFileDialog
        if image is grayscale, change view to default to grayscale
    """

    if parent.load_3D:
        load_3D = True

    if filename is None:
        name = QFileDialog.getOpenFileName(parent, "Load image")
        filename = name[0]
        if filename == "":
            return
    manual_file = os.path.splitext(filename)[0] + "_seg.npy"
    load_mask = False
    if load_seg:
        if os.path.isfile(manual_file) and not parent.autoloadMasks.isChecked():
            if filename is not None:
                image = (imread_2D(filename) if not load_3D else
                         imread_3D(filename))
            else:
                image = None
            _load_seg(parent, manual_file, image=image, image_file=filename,
                      load_3D=load_3D)
            return
        elif parent.autoloadMasks.isChecked():
            mask_file = os.path.splitext(filename)[0] + "_masks" + os.path.splitext(
                filename)[-1]
            mask_file = os.path.splitext(filename)[
                0] + "_masks.tif" if not os.path.isfile(mask_file) else mask_file
            load_mask = os.path.isfile(mask_file)
    try:
        print(f"GUI_INFO: loading image: {filename}")
        if not load_3D:
            image = imread_2D(filename)
        else:
            image = imread_3D(filename)
        parent.loaded = True
    except Exception as e:
        print("ERROR: images not compatible")
        print(f"ERROR: {e}")

    if parent.loaded:
        parent.reset()
        parent.filename = filename
        filename = os.path.split(parent.filename)[-1]
        _initialize_images(parent, image, load_3D=load_3D)
        parent.loaded = True
        parent.enable_buttons()
        if load_mask:
            _load_masks(parent, filename=mask_file)

        # check if gray and adjust viewer:
        if len(np.unique(image[..., 1:])) == 1:
            parent.color = 4
            parent.RGBDropDown.setCurrentIndex(4)  # gray
            parent.update_plot()


def _initialize_images(parent, image, load_3D=False):
    """ format image for GUI

    assumes image is Z x W x H x C

    """
    load_3D = parent.load_3D if load_3D is False else load_3D

    parent.stack = image
    print(f"GUI_INFO: image shape: {image.shape}")
    if load_3D:
        parent.NZ = len(parent.stack)
        parent.scroll.setMaximum(parent.NZ - 1)
    else:
        parent.NZ = 1
        parent.stack = parent.stack[np.newaxis, ...]

    img_min = image.min()
    img_max = image.max()
    parent.stack = parent.stack.astype(np.float32)
    parent.stack -= img_min
    if img_max > img_min + 1e-3:
        parent.stack /= (img_max - img_min)
    parent.stack *= 255

    if load_3D:
        print("GUI_INFO: converted to float and normalized values to 0.0->255.0")

    del image
    gc.collect()

    parent.imask = 0
    parent.Ly, parent.Lx = parent.stack.shape[-3:-1]
    parent.Ly0, parent.Lx0 = parent.stack.shape[-3:-1]
    parent.layerz = 255 * np.ones((parent.Ly, parent.Lx, 4), "uint8")
    if hasattr(parent, "stack_filtered"):
        parent.Lyr, parent.Lxr = parent.stack_filtered.shape[-3:-1]
    elif parent.restore and "upsample" in parent.restore:
        parent.Lyr, parent.Lxr = int(parent.Ly * parent.ratio), int(parent.Lx *
                                                                    parent.ratio)
    else:
        parent.Lyr, parent.Lxr = parent.Ly, parent.Lx
    parent.clear_all()

    if not hasattr(parent, "stack_filtered") and parent.restore:
        print("GUI_INFO: no 'img_restore' found, applying current settings")
        parent.compute_restore()

    if parent.autobtn.isChecked():
        if parent.restore is None or parent.restore != "filter":
            print(
                "GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)"
            )
            parent.compute_saturation()
    # elif len(parent.saturation) != parent.NZ:
    #     parent.saturation = []
    #     for r in range(3):
    #         parent.saturation.append([])
    #         for n in range(parent.NZ):
    #             parent.saturation[-1].append([0, 255])
    #         parent.sliders[r].setValue([0, 255])
    parent.compute_scale()
    parent.track_changes = []

    if load_3D:
        parent.currentZ = int(np.floor(parent.NZ / 2))
        parent.scroll.setValue(parent.currentZ)
        parent.zpos.setText(str(parent.currentZ))
    else:
        parent.currentZ = 0

def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False):
|
| 233 |
+
""" load *_seg.npy with filename; if None, open QFileDialog """
|
| 234 |
+
if filename is None:
|
| 235 |
+
name = QFileDialog.getOpenFileName(parent, "Load labelled data", filter="*.npy")
|
| 236 |
+
filename = name[0]
|
| 237 |
+
try:
|
| 238 |
+
dat = np.load(filename, allow_pickle=True).item()
|
| 239 |
+
# check if there are keys in filename
|
| 240 |
+
dat["outlines"]
|
| 241 |
+
parent.loaded = True
|
| 242 |
+
except:
|
| 243 |
+
parent.loaded = False
|
| 244 |
+
print("ERROR: not NPY")
|
| 245 |
+
return
|
| 246 |
+
|
| 247 |
+
parent.reset()
|
| 248 |
+
if image is None:
|
| 249 |
+
found_image = False
|
| 250 |
+
if "filename" in dat:
|
| 251 |
+
parent.filename = dat["filename"]
|
| 252 |
+
if os.path.isfile(parent.filename):
|
| 253 |
+
parent.filename = dat["filename"]
|
| 254 |
+
found_image = True
|
| 255 |
+
else:
|
| 256 |
+
imgname = os.path.split(parent.filename)[1]
|
| 257 |
+
root = os.path.split(filename)[0]
|
| 258 |
+
parent.filename = root + "/" + imgname
|
| 259 |
+
if os.path.isfile(parent.filename):
|
| 260 |
+
found_image = True
|
| 261 |
+
if found_image:
|
| 262 |
+
try:
|
| 263 |
+
print(parent.filename)
|
| 264 |
+
image = (imread_2D(parent.filename) if not load_3D else
|
| 265 |
+
imread_3D(parent.filename))
|
| 266 |
+
except:
|
| 267 |
+
parent.loaded = False
|
| 268 |
+
found_image = False
|
| 269 |
+
print("ERROR: cannot find image file, loading from npy")
|
| 270 |
+
if not found_image:
|
| 271 |
+
parent.filename = filename[:-8]
|
| 272 |
+
print(parent.filename)
|
| 273 |
+
if "img" in dat:
|
| 274 |
+
image = dat["img"]
|
| 275 |
+
else:
|
| 276 |
+
print("ERROR: no image file found and no image in npy")
|
| 277 |
+
return
|
| 278 |
+
else:
|
| 279 |
+
parent.filename = image_file
|
| 280 |
+
|
| 281 |
+
parent.restore = None
|
| 282 |
+
parent.ratio = 1.
|
| 283 |
+
|
| 284 |
+
if "normalize_params" in dat:
|
| 285 |
+
parent.set_normalize_params(dat["normalize_params"])
|
| 286 |
+
|
| 287 |
+
_initialize_images(parent, image, load_3D=load_3D)
|
| 288 |
+
print(parent.stack.shape)
|
| 289 |
+
|
| 290 |
+
if "outlines" in dat:
|
| 291 |
+
if isinstance(dat["outlines"], list):
|
| 292 |
+
# old way of saving files
|
| 293 |
+
dat["outlines"] = dat["outlines"][::-1]
|
| 294 |
+
for k, outline in enumerate(dat["outlines"]):
|
| 295 |
+
if "colors" in dat:
|
| 296 |
+
color = dat["colors"][k]
|
| 297 |
+
else:
|
| 298 |
+
col_rand = np.random.randint(1000)
|
| 299 |
+
color = parent.colormap[col_rand, :3]
|
| 300 |
+
median = parent.add_mask(points=outline, color=color)
|
| 301 |
+
if median is not None:
|
| 302 |
+
parent.cellcolors = np.append(parent.cellcolors,
|
| 303 |
+
color[np.newaxis, :], axis=0)
|
| 304 |
+
parent.ncells += 1
|
| 305 |
+
else:
|
| 306 |
+
if dat["masks"].min() == -1:
|
| 307 |
+
dat["masks"] += 1
|
| 308 |
+
dat["outlines"] += 1
|
| 309 |
+
parent.ncells.set(dat["masks"].max())
|
| 310 |
+
if "colors" in dat and len(dat["colors"]) == dat["masks"].max():
|
| 311 |
+
colors = dat["colors"]
|
| 312 |
+
else:
|
| 313 |
+
colors = parent.colormap[:parent.ncells.get(), :3]
|
| 314 |
+
|
| 315 |
+
_masks_to_gui(parent, dat["masks"], outlines=dat["outlines"], colors=colors)
|
| 316 |
+
|
| 317 |
+
parent.draw_layer()
|
| 318 |
+
|
| 319 |
+
if "manual_changes" in dat:
|
| 320 |
+
parent.track_changes = dat["manual_changes"]
|
| 321 |
+
print("GUI_INFO: loaded in previous changes")
|
| 322 |
+
if "zdraw" in dat:
|
| 323 |
+
parent.zdraw = dat["zdraw"]
|
| 324 |
+
else:
|
| 325 |
+
parent.zdraw = [None for n in range(parent.ncells.get())]
|
| 326 |
+
parent.loaded = True
|
| 327 |
+
else:
|
| 328 |
+
parent.clear_all()
|
| 329 |
+
|
| 330 |
+
parent.ismanual = np.zeros(parent.ncells.get(), bool)
|
| 331 |
+
if "ismanual" in dat:
|
| 332 |
+
if len(dat["ismanual"]) == parent.ncells:
|
| 333 |
+
parent.ismanual = dat["ismanual"]
|
| 334 |
+
|
| 335 |
+
if "current_channel" in dat:
|
| 336 |
+
parent.color = (dat["current_channel"] + 2) % 5
|
| 337 |
+
parent.RGBDropDown.setCurrentIndex(parent.color)
|
| 338 |
+
|
| 339 |
+
if "flows" in dat:
|
| 340 |
+
parent.flows = dat["flows"]
|
| 341 |
+
try:
|
| 342 |
+
if parent.flows[0].shape[-3] != dat["masks"].shape[-2]:
|
| 343 |
+
Ly, Lx = dat["masks"].shape[-2:]
|
| 344 |
+
for i in range(len(parent.flows)):
|
| 345 |
+
parent.flows[i] = cv2.resize(
|
| 346 |
+
parent.flows[i].squeeze(), (Lx, Ly),
|
| 347 |
+
interpolation=cv2.INTER_NEAREST)[np.newaxis, ...]
|
| 348 |
+
if parent.NZ == 1:
|
| 349 |
+
parent.recompute_masks = True
|
| 350 |
+
else:
|
| 351 |
+
parent.recompute_masks = False
|
| 352 |
+
|
| 353 |
+
except:
|
| 354 |
+
try:
|
| 355 |
+
if len(parent.flows[0]) > 0:
|
| 356 |
+
parent.flows = parent.flows[0]
|
| 357 |
+
except:
|
| 358 |
+
parent.flows = [[], [], [], [], [[]]]
|
| 359 |
+
parent.recompute_masks = False
|
| 360 |
+
|
| 361 |
+
parent.enable_buttons()
|
| 362 |
+
parent.update_layer()
|
| 363 |
+
del dat
|
| 364 |
+
gc.collect()
|
| 365 |
+
|
| 366 |
+
|
| 367 |
+
def _load_masks(parent, filename=None):
|
| 368 |
+
""" load zeros-based masks (0=no cell, 1=cell 1, ...) """
|
| 369 |
+
if filename is None:
|
| 370 |
+
name = QFileDialog.getOpenFileName(parent, "Load masks (PNG or TIFF)")
|
| 371 |
+
filename = name[0]
|
| 372 |
+
print(f"GUI_INFO: loading masks: {filename}")
|
| 373 |
+
masks = imread(filename)
|
| 374 |
+
outlines = None
|
| 375 |
+
if masks.ndim > 3:
|
| 376 |
+
# Z x nchannels x Ly x Lx
|
| 377 |
+
if masks.shape[-1] > 5:
|
| 378 |
+
parent.flows = list(np.transpose(masks[:, :, :, 2:], (3, 0, 1, 2)))
|
| 379 |
+
outlines = masks[..., 1]
|
| 380 |
+
masks = masks[..., 0]
|
| 381 |
+
else:
|
| 382 |
+
parent.flows = list(np.transpose(masks[:, :, :, 1:], (3, 0, 1, 2)))
|
| 383 |
+
masks = masks[..., 0]
|
| 384 |
+
elif masks.ndim == 3:
|
| 385 |
+
if masks.shape[-1] < 5:
|
| 386 |
+
masks = masks[np.newaxis, :, :, 0]
|
| 387 |
+
elif masks.ndim < 3:
|
| 388 |
+
masks = masks[np.newaxis, :, :]
|
| 389 |
+
# masks should be Z x Ly x Lx
|
| 390 |
+
if masks.shape[0] != parent.NZ:
|
| 391 |
+
print("ERROR: masks are not same depth (number of planes) as image stack")
|
| 392 |
+
return
|
| 393 |
+
|
| 394 |
+
_masks_to_gui(parent, masks, outlines)
|
| 395 |
+
if parent.ncells > 0:
|
| 396 |
+
parent.draw_layer()
|
| 397 |
+
parent.toggle_mask_ops()
|
| 398 |
+
del masks
|
| 399 |
+
gc.collect()
|
| 400 |
+
parent.update_layer()
|
| 401 |
+
parent.update_plot()
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
def _masks_to_gui(parent, masks, outlines=None, colors=None):
|
| 405 |
+
""" masks loaded into GUI """
|
| 406 |
+
# get unique values
|
| 407 |
+
shape = masks.shape
|
| 408 |
+
if len(fastremap.unique(masks)) != masks.max() + 1:
|
| 409 |
+
print("GUI_INFO: renumbering masks")
|
| 410 |
+
fastremap.renumber(masks, in_place=True)
|
| 411 |
+
outlines = None
|
| 412 |
+
masks = masks.reshape(shape)
|
| 413 |
+
if masks.ndim == 2:
|
| 414 |
+
outlines = None
|
| 415 |
+
masks = masks.astype(np.uint16) if masks.max() < 2**16 - 1 else masks.astype(
|
| 416 |
+
np.uint32)
|
| 417 |
+
if parent.restore and "upsample" in parent.restore:
|
| 418 |
+
parent.cellpix_resize = masks.copy()
|
| 419 |
+
parent.cellpix = parent.cellpix_resize.copy()
|
| 420 |
+
parent.cellpix_orig = cv2.resize(
|
| 421 |
+
masks.squeeze(), (parent.Lx0, parent.Ly0),
|
| 422 |
+
interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :]
|
| 423 |
+
parent.resize = True
|
| 424 |
+
else:
|
| 425 |
+
parent.cellpix = masks
|
| 426 |
+
if parent.cellpix.ndim == 2:
|
| 427 |
+
parent.cellpix = parent.cellpix[np.newaxis, :, :]
|
| 428 |
+
if parent.restore and "upsample" in parent.restore:
|
| 429 |
+
if parent.cellpix_resize.ndim == 2:
|
| 430 |
+
parent.cellpix_resize = parent.cellpix_resize[np.newaxis, :, :]
|
| 431 |
+
if parent.cellpix_orig.ndim == 2:
|
| 432 |
+
parent.cellpix_orig = parent.cellpix_orig[np.newaxis, :, :]
|
| 433 |
+
|
| 434 |
+
print(f"GUI_INFO: {masks.max()} masks found")
|
| 435 |
+
|
| 436 |
+
# get outlines
|
| 437 |
+
if outlines is None: # parent.outlinesOn
|
| 438 |
+
parent.outpix = np.zeros_like(parent.cellpix)
|
| 439 |
+
if parent.restore and "upsample" in parent.restore:
|
| 440 |
+
parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
|
| 441 |
+
for z in range(parent.NZ):
|
| 442 |
+
outlines = masks_to_outlines(parent.cellpix[z])
|
| 443 |
+
parent.outpix[z] = outlines * parent.cellpix[z]
|
| 444 |
+
if parent.restore and "upsample" in parent.restore:
|
| 445 |
+
outlines = masks_to_outlines(parent.cellpix_orig[z])
|
| 446 |
+
parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
|
| 447 |
+
if z % 50 == 0 and parent.NZ > 1:
|
| 448 |
+
print("GUI_INFO: plane %d outlines processed" % z)
|
| 449 |
+
if parent.restore and "upsample" in parent.restore:
|
| 450 |
+
parent.outpix_resize = parent.outpix.copy()
|
| 451 |
+
else:
|
| 452 |
+
parent.outpix = outlines
|
| 453 |
+
if parent.restore and "upsample" in parent.restore:
|
| 454 |
+
parent.outpix_resize = parent.outpix.copy()
|
| 455 |
+
parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
|
| 456 |
+
for z in range(parent.NZ):
|
| 457 |
+
outlines = masks_to_outlines(parent.cellpix_orig[z])
|
| 458 |
+
parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
|
| 459 |
+
if z % 50 == 0 and parent.NZ > 1:
|
| 460 |
+
print("GUI_INFO: plane %d outlines processed" % z)
|
| 461 |
+
|
| 462 |
+
if parent.outpix.ndim == 2:
|
| 463 |
+
parent.outpix = parent.outpix[np.newaxis, :, :]
|
| 464 |
+
if parent.restore and "upsample" in parent.restore:
|
| 465 |
+
if parent.outpix_resize.ndim == 2:
|
| 466 |
+
parent.outpix_resize = parent.outpix_resize[np.newaxis, :, :]
|
| 467 |
+
if parent.outpix_orig.ndim == 2:
|
| 468 |
+
parent.outpix_orig = parent.outpix_orig[np.newaxis, :, :]
|
| 469 |
+
|
| 470 |
+
parent.ncells.set(parent.cellpix.max())
|
| 471 |
+
colors = parent.colormap[:parent.ncells.get(), :3] if colors is None else colors
|
| 472 |
+
print("GUI_INFO: creating cellcolors and drawing masks")
|
| 473 |
+
parent.cellcolors = np.concatenate((np.array([[255, 255, 255]]), colors),
|
| 474 |
+
axis=0).astype(np.uint8)
|
| 475 |
+
if parent.ncells > 0:
|
| 476 |
+
parent.draw_layer()
|
| 477 |
+
parent.toggle_mask_ops()
|
| 478 |
+
parent.ismanual = np.zeros(parent.ncells.get(), bool)
|
| 479 |
+
parent.zdraw = list(-1 * np.ones(parent.ncells.get(), np.int16))
|
| 480 |
+
|
| 481 |
+
if hasattr(parent, "stack_filtered"):
|
| 482 |
+
parent.ViewDropDown.setCurrentIndex(parent.ViewDropDown.count() - 1)
|
| 483 |
+
print("set denoised/filtered view")
|
| 484 |
+
else:
|
| 485 |
+
parent.ViewDropDown.setCurrentIndex(0)
|
| 486 |
+
|
| 487 |
+
|
| 488 |
+
def _save_png(parent):
|
| 489 |
+
""" save masks to png or tiff (if 3D) """
|
| 490 |
+
filename = parent.filename
|
| 491 |
+
base = os.path.splitext(filename)[0]
|
| 492 |
+
if parent.NZ == 1:
|
| 493 |
+
if parent.cellpix[0].max() > 65534:
|
| 494 |
+
print("GUI_INFO: saving 2D masks to tif (too many masks for PNG)")
|
| 495 |
+
imsave(base + "_cp_masks.tif", parent.cellpix[0])
|
| 496 |
+
else:
|
| 497 |
+
print("GUI_INFO: saving 2D masks to png")
|
| 498 |
+
imsave(base + "_cp_masks.png", parent.cellpix[0].astype(np.uint16))
|
| 499 |
+
else:
|
| 500 |
+
print("GUI_INFO: saving 3D masks to tiff")
|
| 501 |
+
imsave(base + "_cp_masks.tif", parent.cellpix)
|
| 502 |
+
|
| 503 |
+
|
| 504 |
+
def _save_flows(parent):
|
| 505 |
+
""" save flows and cellprob to tiff """
|
| 506 |
+
filename = parent.filename
|
| 507 |
+
base = os.path.splitext(filename)[0]
|
| 508 |
+
print("GUI_INFO: saving flows and cellprob to tiff")
|
| 509 |
+
if len(parent.flows) > 0:
|
| 510 |
+
imsave(base + "_cp_cellprob.tif", parent.flows[1])
|
| 511 |
+
for i in range(3):
|
| 512 |
+
imsave(base + f"_cp_flows_{i}.tif", parent.flows[0][..., i])
|
| 513 |
+
if len(parent.flows) > 2:
|
| 514 |
+
imsave(base + "_cp_flows.tif", parent.flows[2])
|
| 515 |
+
print("GUI_INFO: saved flows and cellprob")
|
| 516 |
+
else:
|
| 517 |
+
print("ERROR: no flows or cellprob found")
|
| 518 |
+
|
| 519 |
+
|
| 520 |
+
def _save_rois(parent):
|
| 521 |
+
""" save masks as rois in .zip file for ImageJ """
|
| 522 |
+
filename = parent.filename
|
| 523 |
+
if parent.NZ == 1:
|
| 524 |
+
print(
|
| 525 |
+
f"GUI_INFO: saving {parent.cellpix[0].max()} ImageJ ROIs to .zip archive.")
|
| 526 |
+
save_rois(parent.cellpix[0], parent.filename)
|
| 527 |
+
else:
|
| 528 |
+
print("ERROR: cannot save 3D outlines")
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def _save_outlines(parent):
|
| 532 |
+
filename = parent.filename
|
| 533 |
+
base = os.path.splitext(filename)[0]
|
| 534 |
+
if parent.NZ == 1:
|
| 535 |
+
print(
|
| 536 |
+
"GUI_INFO: saving 2D outlines to text file, see docs for info to load into ImageJ"
|
| 537 |
+
)
|
| 538 |
+
outlines = outlines_list(parent.cellpix[0])
|
| 539 |
+
outlines_to_text(base, outlines)
|
| 540 |
+
else:
|
| 541 |
+
print("ERROR: cannot save 3D outlines")
|
| 542 |
+
|
| 543 |
+
|
| 544 |
+
def _save_sets_with_check(parent):
|
| 545 |
+
""" Save masks and update *_seg.npy file. Use this function when saving should be optional
|
| 546 |
+
based on the disableAutosave checkbox. Otherwise, use _save_sets """
|
| 547 |
+
if not parent.disableAutosave.isChecked():
|
| 548 |
+
_save_sets(parent)
|
| 549 |
+
|
| 550 |
+
|
| 551 |
+
def _save_sets(parent):
|
| 552 |
+
""" save masks to *_seg.npy. This function should be used when saving
|
| 553 |
+
is forced, e.g. when clicking the save button. Otherwise, use _save_sets_with_check
|
| 554 |
+
"""
|
| 555 |
+
filename = parent.filename
|
| 556 |
+
base = os.path.splitext(filename)[0]
|
| 557 |
+
flow_threshold = parent.segmentation_settings.flow_threshold
|
| 558 |
+
cellprob_threshold = parent.segmentation_settings.cellprob_threshold
|
| 559 |
+
|
| 560 |
+
if parent.NZ > 1:
|
| 561 |
+
dat = {
|
| 562 |
+
"outlines":
|
| 563 |
+
parent.outpix,
|
| 564 |
+
"colors":
|
| 565 |
+
parent.cellcolors[1:],
|
| 566 |
+
"masks":
|
| 567 |
+
parent.cellpix,
|
| 568 |
+
"current_channel": (parent.color - 2) % 5,
|
| 569 |
+
"filename":
|
| 570 |
+
parent.filename,
|
| 571 |
+
"flows":
|
| 572 |
+
parent.flows,
|
| 573 |
+
"zdraw":
|
| 574 |
+
parent.zdraw,
|
| 575 |
+
"model_path":
|
| 576 |
+
parent.current_model_path
|
| 577 |
+
if hasattr(parent, "current_model_path") else 0,
|
| 578 |
+
"flow_threshold":
|
| 579 |
+
flow_threshold,
|
| 580 |
+
"cellprob_threshold":
|
| 581 |
+
cellprob_threshold,
|
| 582 |
+
"normalize_params":
|
| 583 |
+
parent.get_normalize_params(),
|
| 584 |
+
"restore":
|
| 585 |
+
parent.restore,
|
| 586 |
+
"ratio":
|
| 587 |
+
parent.ratio,
|
| 588 |
+
"diameter":
|
| 589 |
+
parent.segmentation_settings.diameter
|
| 590 |
+
}
|
| 591 |
+
if parent.restore is not None:
|
| 592 |
+
dat["img_restore"] = parent.stack_filtered
|
| 593 |
+
else:
|
| 594 |
+
dat = {
|
| 595 |
+
"outlines":
|
| 596 |
+
parent.outpix.squeeze() if parent.restore is None or
|
| 597 |
+
not "upsample" in parent.restore else parent.outpix_resize.squeeze(),
|
| 598 |
+
"colors":
|
| 599 |
+
parent.cellcolors[1:],
|
| 600 |
+
"masks":
|
| 601 |
+
parent.cellpix.squeeze() if parent.restore is None or
|
| 602 |
+
not "upsample" in parent.restore else parent.cellpix_resize.squeeze(),
|
| 603 |
+
"filename":
|
| 604 |
+
parent.filename,
|
| 605 |
+
"flows":
|
| 606 |
+
parent.flows,
|
| 607 |
+
"ismanual":
|
| 608 |
+
parent.ismanual,
|
| 609 |
+
"manual_changes":
|
| 610 |
+
parent.track_changes,
|
| 611 |
+
"model_path":
|
| 612 |
+
parent.current_model_path
|
| 613 |
+
if hasattr(parent, "current_model_path") else 0,
|
| 614 |
+
"flow_threshold":
|
| 615 |
+
flow_threshold,
|
| 616 |
+
"cellprob_threshold":
|
| 617 |
+
cellprob_threshold,
|
| 618 |
+
"normalize_params":
|
| 619 |
+
parent.get_normalize_params(),
|
| 620 |
+
"restore":
|
| 621 |
+
parent.restore,
|
| 622 |
+
"ratio":
|
| 623 |
+
parent.ratio,
|
| 624 |
+
"diameter":
|
| 625 |
+
parent.segmentation_settings.diameter
|
| 626 |
+
}
|
| 627 |
+
if parent.restore is not None:
|
| 628 |
+
dat["img_restore"] = parent.stack_filtered
|
| 629 |
+
try:
|
| 630 |
+
np.save(base + "_seg.npy", dat)
|
| 631 |
+
print("GUI_INFO: %d ROIs saved to %s" % (parent.ncells.get(), base + "_seg.npy"))
|
| 632 |
+
except Exception as e:
|
| 633 |
+
print(f"ERROR: {e}")
|
| 634 |
+
del dat
|
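For reference, the `_seg.npy` dict written by `_save_sets` above can be inspected outside the GUI. A minimal sketch, assuming a 2D segmentation saved through this module (the path `img_seg.npy` is a placeholder):

import numpy as np

# keys come from the dict built in _save_sets above: "masks", "outlines",
# "colors", "flows", "normalize_params", "restore", "ratio", "diameter", ...
dat = np.load("img_seg.npy", allow_pickle=True).item()  # hypothetical filename
masks = dat["masks"]  # 0 = background, 1..N = mask labels
print(f"{masks.max()} masks for {dat['filename']}")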
models/seg_post_model/cellpose/gui/make_train.py
ADDED
@@ -0,0 +1,107 @@
import os, argparse
import numpy as np
from cellpose import io, transforms


def main():
    parser = argparse.ArgumentParser(description='Make slices of XYZ image data for training. Assumes image is ZXYC unless specified otherwise using --channel_axis and --z_axis')

    input_img_args = parser.add_argument_group("input image arguments")
    input_img_args.add_argument('--dir', default=[], type=str,
                                help='folder containing data to run or train on.')
    input_img_args.add_argument(
        '--image_path', default=[], type=str,
        help='if given and --dir not given, run on single image instead of folder (cannot train with this option)')
    input_img_args.add_argument(
        '--look_one_level_down', action='store_true',
        help='run processing on all subdirectories of current folder')
    input_img_args.add_argument('--img_filter', default=[], type=str,
                                help='end string for images to run on')
    input_img_args.add_argument(
        '--channel_axis', default=-1, type=int,
        help='axis of image which corresponds to image channels')
    input_img_args.add_argument('--z_axis', default=0, type=int,
                                help='axis of image which corresponds to Z dimension')
    input_img_args.add_argument('--chan', default=0, type=int, help='Deprecated')
    input_img_args.add_argument('--chan2', default=0, type=int, help='Deprecated')
    input_img_args.add_argument('--invert', action='store_true',
                                help='invert grayscale channel')
    input_img_args.add_argument('--all_channels', action='store_true',
                                help='deprecated')
    input_img_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
                                help="anisotropy of volume in 3D")

    # algorithm settings
    algorithm_args = parser.add_argument_group("algorithm arguments")
    algorithm_args.add_argument('--sharpen_radius', required=False, default=0.0,
                                type=float, help='high-pass filtering radius. Default: %(default)s')
    algorithm_args.add_argument('--tile_norm', required=False, default=0, type=int,
                                help='tile normalization block size. Default: %(default)s')
    algorithm_args.add_argument('--nimg_per_tif', required=False, default=10, type=int,
                                help='number of crops in XY to save per tiff. Default: %(default)s')
    algorithm_args.add_argument('--crop_size', required=False, default=512, type=int,
                                help='size of random crop to save. Default: %(default)s')

    args = parser.parse_args()

    # find images
    if len(args.img_filter) > 0:
        imf = args.img_filter
    else:
        imf = None

    if len(args.dir) > 0:
        image_names = io.get_image_files(args.dir, "_masks", imf=imf,
                                         look_one_level_down=args.look_one_level_down)
        dirname = args.dir
    else:
        if os.path.exists(args.image_path):
            image_names = [args.image_path]
            dirname = os.path.split(args.image_path)[0]
        else:
            raise ValueError(f"ERROR: no file found at {args.image_path}")

    np.random.seed(0)
    nimg_per_tif = args.nimg_per_tif
    crop_size = args.crop_size
    os.makedirs(os.path.join(dirname, 'train/'), exist_ok=True)
    pm = [(0, 1, 2, 3), (2, 0, 1, 3), (1, 0, 2, 3)]
    npm = ["YX", "ZY", "ZX"]
    for name in image_names:
        name0 = os.path.splitext(os.path.split(name)[-1])[0]
        img0 = io.imread_3D(name)
        try:
            img0 = transforms.convert_image(img0, channel_axis=args.channel_axis,
                                            z_axis=args.z_axis, do_3D=True)
        except ValueError:
            print('Error converting image. Did you provide the correct --channel_axis and --z_axis ?')

        for p in range(3):
            img = img0.transpose(pm[p]).copy()
            print(npm[p], img[0].shape)
            Ly, Lx = img.shape[1:3]
            imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]]
            if args.anisotropy > 1.0 and p > 0:
                imgs = transforms.resize_image(imgs, Ly=int(args.anisotropy * Ly), Lx=Lx)
            for k, img in enumerate(imgs):
                if args.tile_norm:
                    img = transforms.normalize99_tile(img, blocksize=args.tile_norm)
                if args.sharpen_radius:
                    img = transforms.smooth_sharpen_img(img,
                                                        sharpen_radius=args.sharpen_radius)
                ly = 0 if Ly - crop_size <= 0 else np.random.randint(0, Ly - crop_size)
                lx = 0 if Lx - crop_size <= 0 else np.random.randint(0, Lx - crop_size)
                io.imsave(os.path.join(dirname, f'train/{name0}_{npm[p]}_{k}.tif'),
                          img[ly:ly + args.crop_size, lx:lx + args.crop_size].squeeze())


if __name__ == '__main__':
    main()
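A usage sketch (hypothetical paths; the flags are the ones defined in the parser above, and the invocation assumes this `cellpose` package is importable):

# python -m cellpose.gui.make_train --dir /path/to/tiffs \
#     --anisotropy 2.0 --nimg_per_tif 10 --crop_size 512
# writes crops named {name}_YX_{k}.tif / {name}_ZY_{k}.tif / {name}_ZX_{k}.tif
# into a train/ subfolder of --dir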
models/seg_post_model/cellpose/gui/menus.py
ADDED
@@ -0,0 +1,145 @@
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
from qtpy.QtWidgets import QAction
|
| 5 |
+
from . import io
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def mainmenu(parent):
|
| 9 |
+
main_menu = parent.menuBar()
|
| 10 |
+
file_menu = main_menu.addMenu("&File")
|
| 11 |
+
# load processed data
|
| 12 |
+
loadImg = QAction("&Load image (*.tif, *.png, *.jpg)", parent)
|
| 13 |
+
loadImg.setShortcut("Ctrl+L")
|
| 14 |
+
loadImg.triggered.connect(lambda: io._load_image(parent))
|
| 15 |
+
file_menu.addAction(loadImg)
|
| 16 |
+
|
| 17 |
+
parent.autoloadMasks = QAction("Autoload masks from _masks.tif file", parent,
|
| 18 |
+
checkable=True)
|
| 19 |
+
parent.autoloadMasks.setChecked(False)
|
| 20 |
+
file_menu.addAction(parent.autoloadMasks)
|
| 21 |
+
|
| 22 |
+
parent.disableAutosave = QAction("Disable autosave _seg.npy file", parent,
|
| 23 |
+
checkable=True)
|
| 24 |
+
parent.disableAutosave.setChecked(False)
|
| 25 |
+
file_menu.addAction(parent.disableAutosave)
|
| 26 |
+
|
| 27 |
+
parent.loadMasks = QAction("Load &masks (*.tif, *.png, *.jpg)", parent)
|
| 28 |
+
parent.loadMasks.setShortcut("Ctrl+M")
|
| 29 |
+
parent.loadMasks.triggered.connect(lambda: io._load_masks(parent))
|
| 30 |
+
file_menu.addAction(parent.loadMasks)
|
| 31 |
+
parent.loadMasks.setEnabled(False)
|
| 32 |
+
|
| 33 |
+
loadManual = QAction("Load &processed/labelled image (*_seg.npy)", parent)
|
| 34 |
+
loadManual.setShortcut("Ctrl+P")
|
| 35 |
+
loadManual.triggered.connect(lambda: io._load_seg(parent))
|
| 36 |
+
file_menu.addAction(loadManual)
|
| 37 |
+
|
| 38 |
+
parent.saveSet = QAction("&Save masks and image (as *_seg.npy)", parent)
|
| 39 |
+
parent.saveSet.setShortcut("Ctrl+S")
|
| 40 |
+
parent.saveSet.triggered.connect(lambda: io._save_sets(parent))
|
| 41 |
+
file_menu.addAction(parent.saveSet)
|
| 42 |
+
parent.saveSet.setEnabled(False)
|
| 43 |
+
|
| 44 |
+
parent.savePNG = QAction("Save masks as P&NG/tif", parent)
|
| 45 |
+
parent.savePNG.setShortcut("Ctrl+N")
|
| 46 |
+
parent.savePNG.triggered.connect(lambda: io._save_png(parent))
|
| 47 |
+
file_menu.addAction(parent.savePNG)
|
| 48 |
+
parent.savePNG.setEnabled(False)
|
| 49 |
+
|
| 50 |
+
parent.saveOutlines = QAction("Save &Outlines as text for imageJ", parent)
|
| 51 |
+
parent.saveOutlines.setShortcut("Ctrl+O")
|
| 52 |
+
parent.saveOutlines.triggered.connect(lambda: io._save_outlines(parent))
|
| 53 |
+
file_menu.addAction(parent.saveOutlines)
|
| 54 |
+
parent.saveOutlines.setEnabled(False)
|
| 55 |
+
|
| 56 |
+
parent.saveROIs = QAction("Save outlines as .zip archive of &ROI files for ImageJ",
|
| 57 |
+
parent)
|
| 58 |
+
parent.saveROIs.setShortcut("Ctrl+R")
|
| 59 |
+
parent.saveROIs.triggered.connect(lambda: io._save_rois(parent))
|
| 60 |
+
file_menu.addAction(parent.saveROIs)
|
| 61 |
+
parent.saveROIs.setEnabled(False)
|
| 62 |
+
|
| 63 |
+
parent.saveFlows = QAction("Save &Flows and cellprob as tif", parent)
|
| 64 |
+
parent.saveFlows.setShortcut("Ctrl+F")
|
| 65 |
+
parent.saveFlows.triggered.connect(lambda: io._save_flows(parent))
|
| 66 |
+
file_menu.addAction(parent.saveFlows)
|
| 67 |
+
parent.saveFlows.setEnabled(False)
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def editmenu(parent):
|
| 71 |
+
main_menu = parent.menuBar()
|
| 72 |
+
edit_menu = main_menu.addMenu("&Edit")
|
| 73 |
+
parent.undo = QAction("Undo previous mask/trace", parent)
|
| 74 |
+
parent.undo.setShortcut("Ctrl+Z")
|
| 75 |
+
parent.undo.triggered.connect(parent.undo_action)
|
| 76 |
+
parent.undo.setEnabled(False)
|
| 77 |
+
edit_menu.addAction(parent.undo)
|
| 78 |
+
|
| 79 |
+
parent.redo = QAction("Undo remove mask", parent)
|
| 80 |
+
parent.redo.setShortcut("Ctrl+Y")
|
| 81 |
+
parent.redo.triggered.connect(parent.undo_remove_action)
|
| 82 |
+
parent.redo.setEnabled(False)
|
| 83 |
+
edit_menu.addAction(parent.redo)
|
| 84 |
+
|
| 85 |
+
parent.ClearButton = QAction("Clear all masks", parent)
|
| 86 |
+
parent.ClearButton.setShortcut("Ctrl+0")
|
| 87 |
+
parent.ClearButton.triggered.connect(parent.clear_all)
|
| 88 |
+
parent.ClearButton.setEnabled(False)
|
| 89 |
+
edit_menu.addAction(parent.ClearButton)
|
| 90 |
+
|
| 91 |
+
parent.remcell = QAction("Remove selected cell (Ctrl+CLICK)", parent)
|
| 92 |
+
parent.remcell.setShortcut("Ctrl+Click")
|
| 93 |
+
parent.remcell.triggered.connect(parent.remove_action)
|
| 94 |
+
parent.remcell.setEnabled(False)
|
| 95 |
+
edit_menu.addAction(parent.remcell)
|
| 96 |
+
|
| 97 |
+
parent.mergecell = QAction("FYI: Merge cells by Alt+Click", parent)
|
| 98 |
+
parent.mergecell.setEnabled(False)
|
| 99 |
+
edit_menu.addAction(parent.mergecell)
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def modelmenu(parent):
|
| 103 |
+
main_menu = parent.menuBar()
|
| 104 |
+
io._init_model_list(parent)
|
| 105 |
+
model_menu = main_menu.addMenu("&Models")
|
| 106 |
+
parent.addmodel = QAction("Add custom torch model to GUI", parent)
|
| 107 |
+
#parent.addmodel.setShortcut("Ctrl+A")
|
| 108 |
+
parent.addmodel.triggered.connect(parent.add_model)
|
| 109 |
+
parent.addmodel.setEnabled(True)
|
| 110 |
+
model_menu.addAction(parent.addmodel)
|
| 111 |
+
|
| 112 |
+
parent.removemodel = QAction("Remove selected custom model from GUI", parent)
|
| 113 |
+
#parent.removemodel.setShortcut("Ctrl+R")
|
| 114 |
+
parent.removemodel.triggered.connect(parent.remove_model)
|
| 115 |
+
parent.removemodel.setEnabled(True)
|
| 116 |
+
model_menu.addAction(parent.removemodel)
|
| 117 |
+
|
| 118 |
+
parent.newmodel = QAction("&Train new model with image+masks in folder", parent)
|
| 119 |
+
parent.newmodel.setShortcut("Ctrl+T")
|
| 120 |
+
parent.newmodel.triggered.connect(parent.new_model)
|
| 121 |
+
parent.newmodel.setEnabled(False)
|
| 122 |
+
model_menu.addAction(parent.newmodel)
|
| 123 |
+
|
| 124 |
+
openTrainHelp = QAction("Training instructions", parent)
|
| 125 |
+
openTrainHelp.triggered.connect(parent.train_help_window)
|
| 126 |
+
model_menu.addAction(openTrainHelp)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def helpmenu(parent):
|
| 130 |
+
main_menu = parent.menuBar()
|
| 131 |
+
help_menu = main_menu.addMenu("&Help")
|
| 132 |
+
|
| 133 |
+
openHelp = QAction("&Help with GUI", parent)
|
| 134 |
+
openHelp.setShortcut("Ctrl+H")
|
| 135 |
+
openHelp.triggered.connect(parent.help_window)
|
| 136 |
+
help_menu.addAction(openHelp)
|
| 137 |
+
|
| 138 |
+
openGUI = QAction("&GUI layout", parent)
|
| 139 |
+
openGUI.setShortcut("Ctrl+G")
|
| 140 |
+
openGUI.triggered.connect(parent.gui_window)
|
| 141 |
+
help_menu.addAction(openGUI)
|
| 142 |
+
|
| 143 |
+
openTrainHelp = QAction("Training instructions", parent)
|
| 144 |
+
openTrainHelp.triggered.connect(parent.train_help_window)
|
| 145 |
+
help_menu.addAction(openTrainHelp)
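These builders only attach QActions to an existing main window; handlers resolve lazily through lambdas or `parent` attributes. A minimal sketch of wiring just the File menu onto a bare window (hypothetical; the real GUI sets up many more attributes before calling `editmenu`/`modelmenu`):

import sys
from qtpy.QtWidgets import QApplication, QMainWindow
from cellpose.gui import menus  # assumes the package layout shown in this diff

app = QApplication(sys.argv)
win = QMainWindow()
menus.mainmenu(win)  # adds File actions; io._load_image etc. fire only on trigger
win.show()
sys.exit(app.exec_())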
models/seg_post_model/cellpose/io.py
ADDED
@@ -0,0 +1,816 @@
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import os, warnings, glob, shutil
|
| 5 |
+
from natsort import natsorted
|
| 6 |
+
import numpy as np
|
| 7 |
+
import cv2
|
| 8 |
+
import tifffile
|
| 9 |
+
import logging, pathlib, sys
|
| 10 |
+
from tqdm import tqdm
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import re
|
| 13 |
+
from .version import version_str
|
| 14 |
+
from roifile import ImagejRoi, roiwrite
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from qtpy import QtGui, QtCore, Qt, QtWidgets
|
| 18 |
+
from qtpy.QtWidgets import QMessageBox
|
| 19 |
+
GUI = True
|
| 20 |
+
except:
|
| 21 |
+
GUI = False
|
| 22 |
+
|
| 23 |
+
try:
|
| 24 |
+
import matplotlib.pyplot as plt
|
| 25 |
+
MATPLOTLIB = True
|
| 26 |
+
except:
|
| 27 |
+
MATPLOTLIB = False
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
import nd2
|
| 31 |
+
ND2 = True
|
| 32 |
+
except:
|
| 33 |
+
ND2 = False
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
import nrrd
|
| 37 |
+
NRRD = True
|
| 38 |
+
except:
|
| 39 |
+
NRRD = False
|
| 40 |
+
|
| 41 |
+
try:
|
| 42 |
+
from google.cloud import storage
|
| 43 |
+
SERVER_UPLOAD = True
|
| 44 |
+
except:
|
| 45 |
+
SERVER_UPLOAD = False
|
| 46 |
+
|
| 47 |
+
io_logger = logging.getLogger(__name__)
|
| 48 |
+
|
| 49 |
+
def logger_setup(cp_path=".cellpose", logfile_name="run.log", stdout_file_replacement=None):
|
| 50 |
+
cp_dir = pathlib.Path.home().joinpath(cp_path)
|
| 51 |
+
cp_dir.mkdir(exist_ok=True)
|
| 52 |
+
log_file = cp_dir.joinpath(logfile_name)
|
| 53 |
+
try:
|
| 54 |
+
log_file.unlink()
|
| 55 |
+
except:
|
| 56 |
+
print('creating new log file')
|
| 57 |
+
handlers = [logging.FileHandler(log_file),]
|
| 58 |
+
if stdout_file_replacement is not None:
|
| 59 |
+
handlers.append(logging.FileHandler(stdout_file_replacement))
|
| 60 |
+
else:
|
| 61 |
+
handlers.append(logging.StreamHandler(sys.stdout))
|
| 62 |
+
logging.basicConfig(
|
| 63 |
+
level=logging.INFO,
|
| 64 |
+
format="%(asctime)s [%(levelname)s] %(message)s",
|
| 65 |
+
handlers=handlers,
|
| 66 |
+
force=True
|
| 67 |
+
)
|
| 68 |
+
logger = logging.getLogger(__name__)
|
| 69 |
+
logger.info(f"WRITING LOG OUTPUT TO {log_file}")
|
| 70 |
+
logger.info(version_str)
|
| 71 |
+
|
| 72 |
+
return logger, log_file
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
from . import utils, plot, transforms
|
| 76 |
+
|
| 77 |
+
# helper function to check for a path; if it doesn't exist, make it
|
| 78 |
+
def check_dir(path):
|
| 79 |
+
if not os.path.isdir(path):
|
| 80 |
+
os.mkdir(path)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def outlines_to_text(base, outlines):
|
| 84 |
+
with open(base + "_cp_outlines.txt", "w") as f:
|
| 85 |
+
for o in outlines:
|
| 86 |
+
xy = list(o.flatten())
|
| 87 |
+
xy_str = ",".join(map(str, xy))
|
| 88 |
+
f.write(xy_str)
|
| 89 |
+
f.write("\n")
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
def load_dax(filename):
|
| 93 |
+
### modified from ZhuangLab github:
|
| 94 |
+
### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156
|
| 95 |
+
|
| 96 |
+
inf_filename = os.path.splitext(filename)[0] + ".inf"
|
| 97 |
+
if not os.path.exists(inf_filename):
|
| 98 |
+
io_logger.critical(
|
| 99 |
+
f"ERROR: no inf file found for dax file {filename}, cannot load dax without it"
|
| 100 |
+
)
|
| 101 |
+
return None
|
| 102 |
+
|
| 103 |
+
### get metadata
|
| 104 |
+
image_height, image_width = None, None
|
| 105 |
+
# extract the movie information from the associated inf file
|
| 106 |
+
size_re = re.compile(r"frame dimensions = ([\d]+) x ([\d]+)")
|
| 107 |
+
length_re = re.compile(r"number of frames = ([\d]+)")
|
| 108 |
+
endian_re = re.compile(r" (big|little) endian")
|
| 109 |
+
|
| 110 |
+
with open(inf_filename, "r") as inf_file:
|
| 111 |
+
lines = inf_file.read().split("\n")
|
| 112 |
+
for line in lines:
|
| 113 |
+
m = size_re.match(line)
|
| 114 |
+
if m:
|
| 115 |
+
image_height = int(m.group(2))
|
| 116 |
+
image_width = int(m.group(1))
|
| 117 |
+
m = length_re.match(line)
|
| 118 |
+
if m:
|
| 119 |
+
number_frames = int(m.group(1))
|
| 120 |
+
m = endian_re.search(line)
|
| 121 |
+
if m:
|
| 122 |
+
if m.group(1) == "big":
|
| 123 |
+
bigendian = 1
|
| 124 |
+
else:
|
| 125 |
+
bigendian = 0
|
| 126 |
+
# set defaults, warn the user that they couldn"t be determined from the inf file.
|
| 127 |
+
if not image_height:
|
| 128 |
+
io_logger.warning("could not determine dax image size, assuming 256x256")
|
| 129 |
+
image_height = 256
|
| 130 |
+
image_width = 256
|
| 131 |
+
|
| 132 |
+
### load image
|
| 133 |
+
img = np.memmap(filename, dtype="uint16",
|
| 134 |
+
shape=(number_frames, image_height, image_width))
|
| 135 |
+
if bigendian:
|
| 136 |
+
img = img.byteswap()
|
| 137 |
+
img = np.array(img)
|
| 138 |
+
|
| 139 |
+
return img
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def imread(filename):
|
| 143 |
+
"""
|
| 144 |
+
Read in an image file with tif or image file type supported by cv2.
|
| 145 |
+
|
| 146 |
+
Args:
|
| 147 |
+
filename (str): The path to the image file.
|
| 148 |
+
|
| 149 |
+
Returns:
|
| 150 |
+
numpy.ndarray: The image data as a NumPy array.
|
| 151 |
+
|
| 152 |
+
Raises:
|
| 153 |
+
None
|
| 154 |
+
|
| 155 |
+
Raises an error if the image file format is not supported.
|
| 156 |
+
|
| 157 |
+
Examples:
|
| 158 |
+
>>> img = imread("image.tif")
|
| 159 |
+
"""
|
| 160 |
+
# ensure that extension check is not case sensitive
|
| 161 |
+
ext = os.path.splitext(filename)[-1].lower()
|
| 162 |
+
if ext == ".tif" or ext == ".tiff" or ext == ".flex":
|
| 163 |
+
with tifffile.TiffFile(filename) as tif:
|
| 164 |
+
ltif = len(tif.pages)
|
| 165 |
+
try:
|
| 166 |
+
full_shape = tif.shaped_metadata[0]["shape"]
|
| 167 |
+
except:
|
| 168 |
+
try:
|
| 169 |
+
page = tif.series[0][0]
|
| 170 |
+
full_shape = tif.series[0].shape
|
| 171 |
+
except:
|
| 172 |
+
ltif = 0
|
| 173 |
+
if ltif < 10:
|
| 174 |
+
img = tif.asarray()
|
| 175 |
+
else:
|
| 176 |
+
page = tif.series[0][0]
|
| 177 |
+
shape, dtype = page.shape, page.dtype
|
| 178 |
+
ltif = int(np.prod(full_shape) / np.prod(shape))
|
| 179 |
+
io_logger.info(f"reading tiff with {ltif} planes")
|
| 180 |
+
img = np.zeros((ltif, *shape), dtype=dtype)
|
| 181 |
+
for i, page in enumerate(tqdm(tif.series[0])):
|
| 182 |
+
img[i] = page.asarray()
|
| 183 |
+
img = img.reshape(full_shape)
|
| 184 |
+
return img
|
| 185 |
+
elif ext == ".dax":
|
| 186 |
+
img = load_dax(filename)
|
| 187 |
+
return img
|
| 188 |
+
elif ext == ".nd2":
|
| 189 |
+
if not ND2:
|
| 190 |
+
io_logger.critical("ERROR: need to 'pip install nd2' to load in .nd2 file")
|
| 191 |
+
return None
|
| 192 |
+
elif ext == ".nrrd":
|
| 193 |
+
if not NRRD:
|
| 194 |
+
io_logger.critical(
|
| 195 |
+
"ERROR: need to 'pip install pynrrd' to load in .nrrd file")
|
| 196 |
+
return None
|
| 197 |
+
else:
|
| 198 |
+
img, metadata = nrrd.read(filename)
|
| 199 |
+
if img.ndim == 3:
|
| 200 |
+
img = img.transpose(2, 0, 1)
|
| 201 |
+
return img
|
| 202 |
+
elif ext != ".npy":
|
| 203 |
+
try:
|
| 204 |
+
img = cv2.imread(filename, -1) #cv2.LOAD_IMAGE_ANYDEPTH)
|
| 205 |
+
if img.ndim > 2:
|
| 206 |
+
img = img[..., [2, 1, 0]]
|
| 207 |
+
return img
|
| 208 |
+
except Exception as e:
|
| 209 |
+
io_logger.critical("ERROR: could not read file, %s" % e)
|
| 210 |
+
return None
|
| 211 |
+
else:
|
| 212 |
+
try:
|
| 213 |
+
dat = np.load(filename, allow_pickle=True).item()
|
| 214 |
+
masks = dat["masks"]
|
| 215 |
+
return masks
|
| 216 |
+
except Exception as e:
|
| 217 |
+
io_logger.critical("ERROR: could not read masks from file, %s" % e)
|
| 218 |
+
return None
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def imread_2D(img_file):
|
| 222 |
+
"""
|
| 223 |
+
Read in a 2D image file and convert it to a 3-channel image. Attempts to do this for multi-channel and grayscale images.
|
| 224 |
+
If the image has more than 3 channels, only the first 3 channels are kept.
|
| 225 |
+
|
| 226 |
+
Args:
|
| 227 |
+
img_file (str): The path to the image file.
|
| 228 |
+
|
| 229 |
+
Returns:
|
| 230 |
+
img_out (numpy.ndarray): The 3-channel image data as a NumPy array.
|
| 231 |
+
"""
|
| 232 |
+
img = imread(img_file)
|
| 233 |
+
return transforms.convert_image(img, do_3D=False)
|
| 234 |
+
|
| 235 |
+
|
| 236 |
+
def imread_3D(img_file):
|
| 237 |
+
"""
|
| 238 |
+
Read in a 3D image file and convert it to have a channel axis last automatically. Attempts to do this for multi-channel and grayscale images.
|
| 239 |
+
|
| 240 |
+
If multichannel image, the channel axis is assumed to be the smallest dimension, and the z axis is the next smallest dimension.
|
| 241 |
+
Use `cellpose.io.imread()` to load the full image without selecting the z and channel axes.
|
| 242 |
+
|
| 243 |
+
Args:
|
| 244 |
+
img_file (str): The path to the image file.
|
| 245 |
+
|
| 246 |
+
Returns:
|
| 247 |
+
img_out (numpy.ndarray): The image data as a NumPy array.
|
| 248 |
+
"""
|
| 249 |
+
img = imread(img_file)
|
| 250 |
+
|
| 251 |
+
dimension_lengths = list(img.shape)
|
| 252 |
+
|
| 253 |
+
# grayscale images:
|
| 254 |
+
if img.ndim == 3:
|
| 255 |
+
channel_axis = None
|
| 256 |
+
# guess at z axis:
|
| 257 |
+
z_axis = np.argmin(dimension_lengths)
|
| 258 |
+
|
| 259 |
+
elif img.ndim == 4:
|
| 260 |
+
# guess at channel axis:
|
| 261 |
+
channel_axis = np.argmin(dimension_lengths)
|
| 262 |
+
|
| 263 |
+
# guess at z axis:
|
| 264 |
+
# set channel axis to max so argmin works:
|
| 265 |
+
dimension_lengths[channel_axis] = max(dimension_lengths)
|
| 266 |
+
z_axis = np.argmin(dimension_lengths)
|
| 267 |
+
|
| 268 |
+
else:
|
| 269 |
+
raise ValueError(f'image shape error, 3D image must 3 or 4 dimensional. Number of dimensions: {img.ndim}')
|
| 270 |
+
|
| 271 |
+
try:
|
| 272 |
+
return transforms.convert_image(img, channel_axis=channel_axis, z_axis=z_axis, do_3D=True)
|
| 273 |
+
except Exception as e:
|
| 274 |
+
io_logger.critical("ERROR: could not read file, %s" % e)
|
| 275 |
+
io_logger.critical("ERROR: Guessed z_axis: %s, channel_axis: %s" % (z_axis, channel_axis))
|
| 276 |
+
return None
|
| 277 |
+
|
| 278 |
+
def remove_model(filename, delete=False):
|
| 279 |
+
""" remove model from .cellpose custom model list """
|
| 280 |
+
filename = os.path.split(filename)[-1]
|
| 281 |
+
from . import models
|
| 282 |
+
model_strings = models.get_user_models()
|
| 283 |
+
if len(model_strings) > 0:
|
| 284 |
+
with open(models.MODEL_LIST_PATH, "w") as textfile:
|
| 285 |
+
for fname in model_strings:
|
| 286 |
+
textfile.write(fname + "\n")
|
| 287 |
+
else:
|
| 288 |
+
# write empty file
|
| 289 |
+
textfile = open(models.MODEL_LIST_PATH, "w")
|
| 290 |
+
textfile.close()
|
| 291 |
+
print(f"{filename} removed from custom model list")
|
| 292 |
+
if delete:
|
| 293 |
+
os.remove(os.fspath(models.MODEL_DIR.joinpath(fname)))
|
| 294 |
+
print("model deleted")
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def add_model(filename):
|
| 298 |
+
""" add model to .cellpose models folder to use with GUI or CLI """
|
| 299 |
+
from . import models
|
| 300 |
+
fname = os.path.split(filename)[-1]
|
| 301 |
+
try:
|
| 302 |
+
shutil.copyfile(filename, os.fspath(models.MODEL_DIR.joinpath(fname)))
|
| 303 |
+
except shutil.SameFileError:
|
| 304 |
+
pass
|
| 305 |
+
print(f"{filename} copied to models folder {os.fspath(models.MODEL_DIR)}")
|
| 306 |
+
if fname not in models.get_user_models():
|
| 307 |
+
with open(models.MODEL_LIST_PATH, "a") as textfile:
|
| 308 |
+
textfile.write(fname + "\n")
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def imsave(filename, arr):
|
| 312 |
+
"""
|
| 313 |
+
Saves an image array to a file.
|
| 314 |
+
|
| 315 |
+
Args:
|
| 316 |
+
filename (str): The name of the file to save the image to.
|
| 317 |
+
arr (numpy.ndarray): The image array to be saved.
|
| 318 |
+
|
| 319 |
+
Returns:
|
| 320 |
+
None
|
| 321 |
+
"""
|
| 322 |
+
ext = os.path.splitext(filename)[-1].lower()
|
| 323 |
+
if ext == ".tif" or ext == ".tiff":
|
| 324 |
+
tifffile.imwrite(filename, data=arr, compression="zlib")
|
| 325 |
+
else:
|
| 326 |
+
if len(arr.shape) > 2:
|
| 327 |
+
arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
|
| 328 |
+
cv2.imwrite(filename, arr)
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):
|
| 332 |
+
"""
|
| 333 |
+
Finds all images in a folder and its subfolders (if specified) with the given file extensions.
|
| 334 |
+
|
| 335 |
+
Args:
|
| 336 |
+
folder (str): The path to the folder to search for images.
|
| 337 |
+
mask_filter (str): The filter for mask files.
|
| 338 |
+
imf (str, optional): The additional filter for image files. Defaults to None.
|
| 339 |
+
look_one_level_down (bool, optional): Whether to search for images in subfolders. Defaults to False.
|
| 340 |
+
|
| 341 |
+
Returns:
|
| 342 |
+
list: A list of image file paths.
|
| 343 |
+
|
| 344 |
+
Raises:
|
| 345 |
+
ValueError: If no files are found in the specified folder.
|
| 346 |
+
ValueError: If no images are found in the specified folder with the supported file extensions.
|
| 347 |
+
ValueError: If no images are found in the specified folder without the mask or flow file endings.
|
| 348 |
+
"""
|
| 349 |
+
mask_filters = ["_cp_output", "_flows", "_flows_0", "_flows_1",
|
| 350 |
+
"_flows_2", "_cellprob", "_masks", mask_filter]
|
| 351 |
+
image_names = []
|
| 352 |
+
if imf is None:
|
| 353 |
+
imf = ""
|
| 354 |
+
|
| 355 |
+
folders = []
|
| 356 |
+
if look_one_level_down:
|
| 357 |
+
folders = natsorted(glob.glob(os.path.join(folder, "*/")))
|
| 358 |
+
folders.append(folder)
|
| 359 |
+
exts = [".png", ".jpg", ".jpeg", ".tif", ".tiff", ".flex", ".dax", ".nd2", ".nrrd"]
|
| 360 |
+
l0 = 0
|
| 361 |
+
al = 0
|
| 362 |
+
for folder in folders:
|
| 363 |
+
all_files = glob.glob(folder + "/*")
|
| 364 |
+
al += len(all_files)
|
| 365 |
+
for ext in exts:
|
| 366 |
+
image_names.extend(glob.glob(folder + f"/*{imf}{ext}"))
|
| 367 |
+
image_names.extend(glob.glob(folder + f"/*{imf}{ext.upper()}"))
|
| 368 |
+
l0 += len(image_names)
|
| 369 |
+
|
| 370 |
+
# return error if no files found
|
| 371 |
+
if al == 0:
|
| 372 |
+
raise ValueError("ERROR: no files in --dir folder ")
|
| 373 |
+
elif l0 == 0:
|
| 374 |
+
raise ValueError(
|
| 375 |
+
"ERROR: no images in --dir folder with extensions .png, .jpg, .jpeg, .tif, .tiff, .flex"
|
| 376 |
+
)
|
| 377 |
+
|
| 378 |
+
image_names = natsorted(image_names)
|
| 379 |
+
imn = []
|
| 380 |
+
for im in image_names:
|
| 381 |
+
imfile = os.path.splitext(im)[0]
|
| 382 |
+
igood = all([(len(imfile) > len(mask_filter) and
|
| 383 |
+
imfile[-len(mask_filter):] != mask_filter) or
|
| 384 |
+
len(imfile) <= len(mask_filter) for mask_filter in mask_filters])
|
| 385 |
+
if len(imf) > 0:
|
| 386 |
+
igood &= imfile[-len(imf):] == imf
|
| 387 |
+
if igood:
|
| 388 |
+
imn.append(im)
|
| 389 |
+
|
| 390 |
+
image_names = imn
|
| 391 |
+
|
| 392 |
+
# remove duplicates
|
| 393 |
+
image_names = [*set(image_names)]
|
| 394 |
+
image_names = natsorted(image_names)
|
| 395 |
+
|
| 396 |
+
if len(image_names) == 0:
|
| 397 |
+
raise ValueError(
|
| 398 |
+
"ERROR: no images in --dir folder without _masks or _flows or _cellprob ending")
|
| 399 |
+
|
| 400 |
+
return image_names
|
| 401 |
+
|
| 402 |
+
def get_label_files(image_names, mask_filter, imf=None):
|
| 403 |
+
"""
|
| 404 |
+
Get the label files corresponding to the given image names and mask filter.
|
| 405 |
+
|
| 406 |
+
Args:
|
| 407 |
+
image_names (list): List of image names.
|
| 408 |
+
mask_filter (str): Mask filter to be applied.
|
| 409 |
+
imf (str, optional): Image file extension. Defaults to None.
|
| 410 |
+
|
| 411 |
+
Returns:
|
| 412 |
+
tuple: A tuple containing the label file names and flow file names (if present).
|
| 413 |
+
"""
|
| 414 |
+
nimg = len(image_names)
|
| 415 |
+
label_names0 = [os.path.splitext(image_names[n])[0] for n in range(nimg)]
|
| 416 |
+
|
| 417 |
+
if imf is not None and len(imf) > 0:
|
| 418 |
+
label_names = [label_names0[n][:-len(imf)] for n in range(nimg)]
|
| 419 |
+
else:
|
| 420 |
+
label_names = label_names0
|
| 421 |
+
|
| 422 |
+
# check for flows
|
| 423 |
+
if os.path.exists(label_names0[0] + "_flows.tif"):
|
| 424 |
+
flow_names = [label_names0[n] + "_flows.tif" for n in range(nimg)]
|
| 425 |
+
else:
|
| 426 |
+
flow_names = [label_names[n] + "_flows.tif" for n in range(nimg)]
|
| 427 |
+
if not all([os.path.exists(flow) for flow in flow_names]):
|
| 428 |
+
io_logger.info(
|
| 429 |
+
"not all flows are present, running flow generation for all images")
|
| 430 |
+
flow_names = None
|
| 431 |
+
|
| 432 |
+
# check for masks
|
| 433 |
+
if mask_filter == "_seg.npy":
|
| 434 |
+
label_names = [label_names[n] + mask_filter for n in range(nimg)]
|
| 435 |
+
return label_names, None
|
| 436 |
+
|
| 437 |
+
if os.path.exists(label_names[0] + mask_filter + ".tif"):
|
| 438 |
+
label_names = [label_names[n] + mask_filter + ".tif" for n in range(nimg)]
|
| 439 |
+
elif os.path.exists(label_names[0] + mask_filter + ".tiff"):
|
| 440 |
+
label_names = [label_names[n] + mask_filter + ".tiff" for n in range(nimg)]
|
| 441 |
+
elif os.path.exists(label_names[0] + mask_filter + ".png"):
|
| 442 |
+
label_names = [label_names[n] + mask_filter + ".png" for n in range(nimg)]
|
| 443 |
+
# TODO, allow _seg.npy
|
| 444 |
+
#elif os.path.exists(label_names[0] + "_seg.npy"):
|
| 445 |
+
# io_logger.info("labels found as _seg.npy files, converting to tif")
|
| 446 |
+
else:
|
| 447 |
+
if not flow_names:
|
| 448 |
+
raise ValueError("labels not provided with correct --mask_filter")
|
| 449 |
+
else:
|
| 450 |
+
label_names = None
|
| 451 |
+
if not all([os.path.exists(label) for label in label_names]):
|
| 452 |
+
if not flow_names:
|
| 453 |
+
raise ValueError(
|
| 454 |
+
"labels not provided for all images in train and/or test set")
|
| 455 |
+
else:
|
| 456 |
+
label_names = None
|
| 457 |
+
|
| 458 |
+
return label_names, flow_names
|
| 459 |
+
|
| 460 |
+
|
| 461 |
+
def load_images_labels(tdir, mask_filter="_masks", image_filter=None,
|
| 462 |
+
look_one_level_down=False):
|
| 463 |
+
"""
|
| 464 |
+
Loads images and corresponding labels from a directory.
|
| 465 |
+
|
| 466 |
+
Args:
|
| 467 |
+
tdir (str): The directory path.
|
| 468 |
+
mask_filter (str, optional): The filter for mask files. Defaults to "_masks".
|
| 469 |
+
image_filter (str, optional): The filter for image files. Defaults to None.
|
| 470 |
+
look_one_level_down (bool, optional): Whether to look for files one level down. Defaults to False.
|
| 471 |
+
|
| 472 |
+
Returns:
|
| 473 |
+
tuple: A tuple containing a list of images, a list of labels, and a list of image names.
|
| 474 |
+
"""
|
| 475 |
+
image_names = get_image_files(tdir, mask_filter, image_filter, look_one_level_down)
|
| 476 |
+
nimg = len(image_names)
|
| 477 |
+
|
| 478 |
+
# training data
|
| 479 |
+
label_names, flow_names = get_label_files(image_names, mask_filter,
|
| 480 |
+
imf=image_filter)
|
| 481 |
+
|
| 482 |
+
images = []
|
| 483 |
+
labels = []
|
| 484 |
+
k = 0
|
| 485 |
+
for n in range(nimg):
|
| 486 |
+
if (os.path.isfile(label_names[n]) or
|
| 487 |
+
(flow_names is not None and os.path.isfile(flow_names[0]))):
|
| 488 |
+
image = imread(image_names[n])
|
| 489 |
+
if label_names is not None:
|
| 490 |
+
label = imread(label_names[n])
|
| 491 |
+
if flow_names is not None:
|
| 492 |
+
flow = imread(flow_names[n])
|
| 493 |
+
if flow.shape[0] < 4:
|
| 494 |
+
label = np.concatenate((label[np.newaxis, :, :], flow), axis=0)
|
| 495 |
+
else:
|
| 496 |
+
label = flow
|
| 497 |
+
images.append(image)
|
| 498 |
+
labels.append(label)
|
| 499 |
+
k += 1
|
| 500 |
+
io_logger.info(f"{k} / {nimg} images in {tdir} folder have labels")
|
| 501 |
+
return images, labels, image_names
|
| 502 |
+
|
| 503 |
+
def load_train_test_data(train_dir, test_dir=None, image_filter=None,
                         mask_filter="_masks", look_one_level_down=False):
    """
    Loads training and testing data for a Cellpose model.

    Args:
        train_dir (str): The directory path containing the training data.
        test_dir (str, optional): The directory path containing the testing data. Defaults to None.
        image_filter (str, optional): The filter for selecting image files. Defaults to None.
        mask_filter (str, optional): The filter for selecting mask files. Defaults to "_masks".
        look_one_level_down (bool, optional): Whether to look for data in subdirectories of train_dir and test_dir. Defaults to False.

    Returns:
        images, labels, image_names, test_images, test_labels, test_image_names
    """
    images, labels, image_names = load_images_labels(train_dir, mask_filter,
                                                     image_filter, look_one_level_down)
    # testing data
    test_images, test_labels, test_image_names = None, None, None
    if test_dir is not None:
        test_images, test_labels, test_image_names = load_images_labels(
            test_dir, mask_filter, image_filter, look_one_level_down)

    return images, labels, image_names, test_images, test_labels, test_image_names

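A short usage sketch (directory names are hypothetical):

# Hypothetical train/test folders with "_masks" label files:
out = load_train_test_data("data/train", test_dir="data/test",
                           image_filter="_img", mask_filter="_masks")
images, labels, image_names, test_images, test_labels, test_image_names = out
print(f"{len(images)} training images, "
      f"{0 if test_images is None else len(test_images)} test images")
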
def masks_flows_to_seg(images, masks, flows, file_names,
                       channels=None,
                       imgs_restore=None, restore_type=None, ratio=1.):
    """Save output of model eval to be loaded in GUI.

    Can be list output (run on multiple images) or single output (run on single image).

    Saved to file_names[k]+"_seg.npy".

    Args:
        images (list): Images input into cellpose.
        masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
        flows (list): Flows output from Cellpose.eval.
        file_names (list, str): Names of files of images.
        channels (list, int, optional): Channels used to run Cellpose. Defaults to None.
        imgs_restore (list, optional): Restored images (for restoration models). Defaults to None.
        restore_type (str, optional): Type of restoration applied. Defaults to None.
        ratio (float, optional): Upsampling ratio used during restoration. Defaults to 1.

    Returns:
        None
    """

    if channels is None:
        channels = [0, 0]

    if isinstance(masks, list):
        if imgs_restore is None:
            imgs_restore = [None] * len(masks)
        if isinstance(file_names, str):
            file_names = [file_names] * len(masks)
        for k, [image, mask, flow,
                # diam,
                file_name, img_restore
                ] in enumerate(zip(images, masks, flows,
                                   # diams,
                                   file_names,
                                   imgs_restore)):
            channels_img = channels
            if channels_img is not None and len(channels) > 2:
                channels_img = channels[k]
            masks_flows_to_seg(image, mask, flow, file_name,
                               # diams=diam,
                               channels=channels_img, imgs_restore=img_restore,
                               restore_type=restore_type, ratio=ratio)
        return

    if len(channels) == 1:
        channels = channels[0]

    flowi = []
    if flows[0].ndim == 3:
        Ly, Lx = masks.shape[-2:]
        flowi.append(
            cv2.resize(flows[0], (Lx, Ly),
                       interpolation=cv2.INTER_NEAREST)[np.newaxis, ...])
    else:
        flowi.append(flows[0])

    if flows[0].ndim == 3:
        cellprob = (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(
            np.uint8)
        cellprob = cv2.resize(cellprob, (Lx, Ly), interpolation=cv2.INTER_NEAREST)
        flowi.append(cellprob[np.newaxis, ...])
        flowi.append(np.zeros(flows[0].shape, dtype=np.uint8))
        flowi[-1] = flowi[-1][np.newaxis, ...]
    else:
        flowi.append(
            (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(np.uint8))
        flowi.append((flows[1][0] / 10 * 127 + 127).astype(np.uint8))
    if len(flows) > 2:
        if len(flows) > 3:
            flowi.append(flows[3])
        else:
            flowi.append([])
        flowi.append(np.concatenate((flows[1], flows[2][np.newaxis, ...]), axis=0))
    outlines = masks * utils.masks_to_outlines(masks)
    base = os.path.splitext(file_names)[0]

    dat = {
        "outlines": outlines.astype(np.uint16) if outlines.max() < 2**16 - 1
                    else outlines.astype(np.uint32),
        "masks": masks.astype(np.uint16) if outlines.max() < 2**16 - 1
                 else masks.astype(np.uint32),
        "chan_choose": channels,
        "ismanual": np.zeros(masks.max(), bool),
        "filename": file_names,
        "flows": flowi,
        "diameter": np.nan
    }
    if restore_type is not None and imgs_restore is not None:
        dat["restore"] = restore_type
        dat["ratio"] = ratio
        dat["img_restore"] = imgs_restore

    np.save(base + "_seg.npy", dat)

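The saved file is a pickled dict, so it can be inspected directly; a minimal sketch (the path is hypothetical):

import numpy as np

# Load a saved "_seg.npy" back (allow_pickle is required, the file stores a dict):
dat = np.load("img_seg.npy", allow_pickle=True).item()
masks = dat["masks"]        # label image, 0 = background
outlines = dat["outlines"]  # labelled outline pixels
print(dat["filename"], masks.max(), "masks")
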
def save_to_png(images, masks, flows, file_names):
    """ deprecated (runs io.save_masks with png=True)

    does not work for 3D images

    """
    save_masks(images, masks, flows, file_names, png=True)


def save_rois(masks, file_name, multiprocessing=None):
    """ save masks to .roi files in .zip archive for ImageJ/Fiji

    Args:
        masks (np.ndarray): masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels
        file_name (str): name to save the .zip file to
        multiprocessing (bool, optional): whether to parallelize outline extraction

    Returns:
        None
    """
    outlines = utils.outlines_list(masks, multiprocessing=multiprocessing)
    nonempty_outlines = [outline for outline in outlines if len(outline) != 0]
    if len(outlines) != len(nonempty_outlines):
        print(f"empty outlines found, saving {len(nonempty_outlines)} ImageJ ROIs to .zip archive.")
    rois = [ImagejRoi.frompoints(outline) for outline in nonempty_outlines]
    file_name = os.path.splitext(file_name)[0] + '_rois.zip'

    # Delete file if it exists; the roifile lib appends to existing zip files.
    # If the user removed a mask it would otherwise still be in the zip file.
    if os.path.exists(file_name):
        os.remove(file_name)

    roiwrite(file_name, rois)

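A short usage sketch (the mask array below is synthetic):

import numpy as np

# Synthetic two-label mask; writes "toy_rois.zip" next to the given name:
toy = np.zeros((64, 64), np.uint16)
toy[8:24, 8:24] = 1
toy[32:56, 30:50] = 2
save_rois(toy, "toy.png")
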
def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[0, 0],
               suffix="_cp_masks", save_flows=False, save_outlines=False, dir_above=False,
               in_folders=False, savedir=None, save_txt=False, save_mpl=False):
    """ Save masks + nicely plotted segmentation image to png and/or tiff.

    Can save masks, flows to different directories, if in_folders is True.

    If png, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.png".

    If tif, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.tif".

    If png and matplotlib installed, full segmentation figure is saved to file_names[k]+"_cp.png".

    Only tif option works for 3D data, and only tif option works for empty masks.

    Args:
        images (list): Images input into cellpose.
        masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
        flows (list): Flows output from Cellpose.eval.
        file_names (list, str): Names of files of images.
        png (bool, optional): Save masks to PNG. Defaults to True.
        tif (bool, optional): Save masks to TIF. Defaults to False.
        channels (list, int, optional): Channels used to run Cellpose. Defaults to [0,0].
        suffix (str, optional): Add name to saved masks. Defaults to "_cp_masks".
        save_flows (bool, optional): Save flows output from Cellpose.eval. Defaults to False.
        save_outlines (bool, optional): Save outlines of masks. Defaults to False.
        dir_above (bool, optional): Save masks/flows in directory above. Defaults to False.
        in_folders (bool, optional): Save masks/flows in separate folders. Defaults to False.
        savedir (str, optional): Absolute path where images will be saved. If None, saves to image directory. Defaults to None.
        save_txt (bool, optional): Save masks as list of outlines for ImageJ. Defaults to False.
        save_mpl (bool, optional): If True, saves a matplotlib figure of the original image/segmentation/flows. Does not work for 3D.
            This takes a long time for large images. Defaults to False.

    Returns:
        None
    """

    if isinstance(masks, list):
        for image, mask, flow, file_name in zip(images, masks, flows, file_names):
            save_masks(image, mask, flow, file_name, png=png, tif=tif, suffix=suffix,
                       dir_above=dir_above, save_flows=save_flows,
                       save_outlines=save_outlines, savedir=savedir, save_txt=save_txt,
                       in_folders=in_folders, save_mpl=save_mpl)
        return

    if masks.ndim > 2 and not tif:
        raise ValueError("cannot save 3D outputs as PNG, use tif option instead")

    if masks.max() == 0:
        io_logger.warning("no masks found, will not save PNG or outlines")
        if not tif:
            return
        else:
            png = False
            save_outlines = False
            save_flows = False
            save_txt = False

    if savedir is None:
        if dir_above:
            # go up a level to save in its own folder
            savedir = Path(file_names).parent.parent.absolute()
        else:
            savedir = Path(file_names).parent.absolute()

    check_dir(savedir)

    basename = os.path.splitext(os.path.basename(file_names))[0]
    if in_folders:
        maskdir = os.path.join(savedir, "masks")
        outlinedir = os.path.join(savedir, "outlines")
        txtdir = os.path.join(savedir, "txt_outlines")
        flowdir = os.path.join(savedir, "flows")
    else:
        maskdir = savedir
        outlinedir = savedir
        txtdir = savedir
        flowdir = savedir

    check_dir(maskdir)

    exts = []
    if masks.ndim > 2:
        png = False
        tif = True
    if png:
        if masks.max() < 2**16:
            masks = masks.astype(np.uint16)
            exts.append(".png")
        else:
            png = False
            tif = True
            io_logger.warning(
                "found more than 65535 masks in each image, cannot save PNG, saving as TIF"
            )
    if tif:
        exts.append(".tif")

    # save masks
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for ext in exts:
            imsave(os.path.join(maskdir, basename + suffix + ext), masks)

    if save_mpl and png and MATPLOTLIB and not min(images.shape) > 3:
        # Make and save original/segmentation/flows image
        img = images.copy()
        if img.ndim < 3:
            img = img[:, :, np.newaxis]
        elif img.shape[0] < 8:
            img = np.transpose(img, (1, 2, 0))

        fig = plt.figure(figsize=(12, 3))
        plot.show_segmentation(fig, img, masks, flows[0])
        fig.savefig(os.path.join(savedir, basename + "_cp_output" + suffix + ".png"),
                    dpi=300)
        plt.close(fig)

    # ImageJ txt outline files
    if masks.ndim < 3 and save_txt:
        check_dir(txtdir)
        outlines = utils.outlines_list(masks)
        outlines_to_text(os.path.join(txtdir, basename), outlines)

    # RGB outline images
    if masks.ndim < 3 and save_outlines:
        check_dir(outlinedir)
        outlines = utils.masks_to_outlines(masks)
        outX, outY = np.nonzero(outlines)
        img0 = transforms.normalize99(images)
        if img0.shape[0] < 4:
            img0 = np.transpose(img0, (1, 2, 0))
        if img0.shape[-1] < 3 or img0.ndim < 3:
            img0 = plot.image_to_rgb(img0, channels=channels)
        else:
            if img0.max() <= 50.0:
                img0 = np.uint8(np.clip(img0, 0, 1) * 255)
        imgout = img0.copy()
        imgout[outX, outY] = np.array([255, 0, 0])  # pure red
        imsave(os.path.join(outlinedir, basename + "_outlines" + suffix + ".png"),
               imgout)

    # save RGB flow picture
    if masks.ndim < 3 and save_flows:
        check_dir(flowdir)
        imsave(os.path.join(flowdir, basename + "_flows" + suffix + ".tif"),
               (flows[0] * (2**16 - 1)).astype(np.uint16))
        # save full flow data
        imsave(os.path.join(flowdir, basename + '_dP' + suffix + '.tif'), flows[1])
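A short end-to-end sketch tying eval output to this saver (the model call is shown schematically and the file name is hypothetical):

# Assuming masks/flows came from CellposeModel.eval on a single image `img`:
# masks, flows, styles = model.eval(img)
save_masks(img, masks, flows, "cell01.png",
           png=True, save_outlines=True, save_txt=True)
# -> cell01_cp_masks.png, cell01_outlines_cp_masks.png,
#    plus an ImageJ outlines .txt file in the same folder
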
models/seg_post_model/cellpose/metrics.py
ADDED
@@ -0,0 +1,205 @@
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import numpy as np
|
| 5 |
+
from . import utils
|
| 6 |
+
from scipy.optimize import linear_sum_assignment
|
| 7 |
+
from scipy.ndimage import convolve
|
| 8 |
+
from scipy.sparse import csr_matrix
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def mask_ious(masks_true, masks_pred):
|
| 12 |
+
"""Return best-matched masks."""
|
| 13 |
+
iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
|
| 14 |
+
n_min = min(iou.shape[0], iou.shape[1])
|
| 15 |
+
costs = -(iou >= 0.5).astype(float) - iou / (2 * n_min)
|
| 16 |
+
true_ind, pred_ind = linear_sum_assignment(costs)
|
| 17 |
+
iout = np.zeros(masks_true.max())
|
| 18 |
+
iout[true_ind] = iou[true_ind, pred_ind]
|
| 19 |
+
preds = np.zeros(masks_true.max(), "int")
|
| 20 |
+
preds[true_ind] = pred_ind + 1
|
| 21 |
+
return iout, preds
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def boundary_scores(masks_true, masks_pred, scales):
|
| 25 |
+
"""
|
| 26 |
+
Calculate boundary precision, recall, and F-score.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
masks_true (list): List of true masks.
|
| 30 |
+
masks_pred (list): List of predicted masks.
|
| 31 |
+
scales (list): List of scales.
|
| 32 |
+
|
| 33 |
+
Returns:
|
| 34 |
+
tuple: A tuple containing precision, recall, and F-score arrays.
|
| 35 |
+
"""
|
| 36 |
+
diams = [utils.diameters(lbl)[0] for lbl in masks_true]
|
| 37 |
+
precision = np.zeros((len(scales), len(masks_true)))
|
| 38 |
+
recall = np.zeros((len(scales), len(masks_true)))
|
| 39 |
+
fscore = np.zeros((len(scales), len(masks_true)))
|
| 40 |
+
for j, scale in enumerate(scales):
|
| 41 |
+
for n in range(len(masks_true)):
|
| 42 |
+
diam = max(1, scale * diams[n])
|
| 43 |
+
rs, ys, xs = utils.circleMask([int(np.ceil(diam)), int(np.ceil(diam))])
|
| 44 |
+
filt = (rs <= diam).astype(np.float32)
|
| 45 |
+
otrue = utils.masks_to_outlines(masks_true[n])
|
| 46 |
+
otrue = convolve(otrue, filt)
|
| 47 |
+
opred = utils.masks_to_outlines(masks_pred[n])
|
| 48 |
+
opred = convolve(opred, filt)
|
| 49 |
+
tp = np.logical_and(otrue == 1, opred == 1).sum()
|
| 50 |
+
fp = np.logical_and(otrue == 0, opred == 1).sum()
|
| 51 |
+
fn = np.logical_and(otrue == 1, opred == 0).sum()
|
| 52 |
+
precision[j, n] = tp / (tp + fp)
|
| 53 |
+
recall[j, n] = tp / (tp + fn)
|
| 54 |
+
fscore[j] = 2 * precision[j] * recall[j] / (precision[j] + recall[j])
|
| 55 |
+
return precision, recall, fscore
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def aggregated_jaccard_index(masks_true, masks_pred):
|
| 59 |
+
"""
|
| 60 |
+
AJI = intersection of all matched masks / union of all masks
|
| 61 |
+
|
| 62 |
+
Args:
|
| 63 |
+
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
|
| 64 |
+
where 0=NO masks; 1,2... are mask labels
|
| 65 |
+
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
|
| 66 |
+
np.ndarray (int) where 0=NO masks; 1,2... are mask labels
|
| 67 |
+
|
| 68 |
+
Returns:
|
| 69 |
+
aji (float): aggregated jaccard index for each set of masks
|
| 70 |
+
"""
|
| 71 |
+
aji = np.zeros(len(masks_true))
|
| 72 |
+
for n in range(len(masks_true)):
|
| 73 |
+
iout, preds = mask_ious(masks_true[n], masks_pred[n])
|
| 74 |
+
inds = np.arange(0, masks_true[n].max(), 1, int)
|
| 75 |
+
overlap = _label_overlap(masks_true[n], masks_pred[n])
|
| 76 |
+
union = np.logical_or(masks_true[n] > 0, masks_pred[n] > 0).sum()
|
| 77 |
+
overlap = overlap[inds[preds > 0] + 1, preds[preds > 0].astype(int)]
|
| 78 |
+
aji[n] = overlap.sum() / union
|
| 79 |
+
return aji
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def average_precision(masks_true, masks_pred, threshold=[0.5, 0.75, 0.9]):
|
| 83 |
+
"""
|
| 84 |
+
Average precision estimation: AP = TP / (TP + FP + FN)
|
| 85 |
+
|
| 86 |
+
This function is based heavily on the *fast* stardist matching functions
|
| 87 |
+
(https://github.com/mpicbg-csbd/stardist/blob/master/stardist/matching.py)
|
| 88 |
+
|
| 89 |
+
Args:
|
| 90 |
+
masks_true (list of np.ndarrays (int) or np.ndarray (int)):
|
| 91 |
+
where 0=NO masks; 1,2... are mask labels
|
| 92 |
+
masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
|
| 93 |
+
np.ndarray (int) where 0=NO masks; 1,2... are mask labels
|
| 94 |
+
|
| 95 |
+
Returns:
|
| 96 |
+
ap (array [len(masks_true) x len(threshold)]):
|
| 97 |
+
average precision at thresholds
|
| 98 |
+
tp (array [len(masks_true) x len(threshold)]):
|
| 99 |
+
number of true positives at thresholds
|
| 100 |
+
fp (array [len(masks_true) x len(threshold)]):
|
| 101 |
+
number of false positives at thresholds
|
| 102 |
+
fn (array [len(masks_true) x len(threshold)]):
|
| 103 |
+
number of false negatives at thresholds
|
| 104 |
+
"""
|
| 105 |
+
not_list = False
|
| 106 |
+
if not isinstance(masks_true, list):
|
| 107 |
+
masks_true = [masks_true]
|
| 108 |
+
masks_pred = [masks_pred]
|
| 109 |
+
not_list = True
|
| 110 |
+
if not isinstance(threshold, list) and not isinstance(threshold, np.ndarray):
|
| 111 |
+
threshold = [threshold]
|
| 112 |
+
|
| 113 |
+
if len(masks_true) != len(masks_pred):
|
| 114 |
+
raise ValueError(
|
| 115 |
+
"metrics.average_precision requires len(masks_true)==len(masks_pred)")
|
| 116 |
+
|
| 117 |
+
ap = np.zeros((len(masks_true), len(threshold)), np.float32)
|
| 118 |
+
tp = np.zeros((len(masks_true), len(threshold)), np.float32)
|
| 119 |
+
fp = np.zeros((len(masks_true), len(threshold)), np.float32)
|
| 120 |
+
fn = np.zeros((len(masks_true), len(threshold)), np.float32)
|
| 121 |
+
n_true = np.array([len(np.unique(mt)) - 1 for mt in masks_true])
|
| 122 |
+
n_pred = np.array([len(np.unique(mp)) - 1 for mp in masks_pred])
|
| 123 |
+
|
| 124 |
+
for n in range(len(masks_true)):
|
| 125 |
+
#_,mt = np.reshape(np.unique(masks_true[n], return_index=True), masks_pred[n].shape)
|
| 126 |
+
if n_pred[n] > 0:
|
| 127 |
+
iou = _intersection_over_union(masks_true[n], masks_pred[n])[1:, 1:]
|
| 128 |
+
for k, th in enumerate(threshold):
|
| 129 |
+
tp[n, k] = _true_positive(iou, th)
|
| 130 |
+
fp[n] = n_pred[n] - tp[n]
|
| 131 |
+
fn[n] = n_true[n] - tp[n]
|
| 132 |
+
ap[n] = tp[n] / (tp[n] + fp[n] + fn[n])
|
| 133 |
+
|
| 134 |
+
if not_list:
|
| 135 |
+
ap, tp, fp, fn = ap[0], tp[0], fp[0], fn[0]
|
| 136 |
+
return ap, tp, fp, fn
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
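A quick sanity-check sketch with synthetic labels:

import numpy as np

# Two ground-truth objects; the prediction matches one perfectly and misses the other.
gt = np.zeros((40, 40), np.uint16)
gt[2:12, 2:12] = 1
gt[20:32, 20:32] = 2
pred = np.zeros_like(gt)
pred[2:12, 2:12] = 1
ap, tp, fp, fn = average_precision(gt, pred, threshold=0.5)
# tp=1, fp=0, fn=1  ->  ap = 1 / (1 + 0 + 1) = 0.5
print(ap, tp, fp, fn)
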
def _intersection_over_union(masks_true, masks_pred):
    """Calculate the intersection over union of all mask pairs.

    Parameters:
        masks_true (np.ndarray, int): Ground truth masks, where 0=NO masks; 1,2... are mask labels.
        masks_pred (np.ndarray, int): Predicted masks, where 0=NO masks; 1,2... are mask labels.

    Returns:
        iou (np.ndarray, float): Matrix of IOU pairs of size [x.max()+1, y.max()+1].

    How it works:
        The overlap matrix is a lookup table of the area of intersection
        between each set of labels (true and predicted). The true labels
        are taken to be along axis 0, and the predicted labels are taken
        to be along axis 1. The sum of the overlaps along axis 0 is thus
        an array giving the total overlap of the true labels with each of
        the predicted labels, and likewise the sum over axis 1 is the
        total overlap of the predicted labels with each of the true labels.
        Because the label 0 (background) is included, this sum is guaranteed
        to reconstruct the total area of each label. Adding these row and
        column vectors gives a 2D array with the areas of every label pair
        added together. This is equivalent to the union of the label areas
        except for the duplicated overlap area, so the overlap matrix is
        subtracted to find the union matrix.
    """
    if masks_true.size != masks_pred.size:
        raise ValueError(
            f"masks_true shape {masks_true.shape} does not match masks_pred shape {masks_pred.shape}")
    overlap = csr_matrix((np.ones((masks_true.size,), "int"),
                          (masks_true.flatten(), masks_pred.flatten())),
                         shape=(masks_true.max() + 1, masks_pred.max() + 1))
    overlap = overlap.toarray()
    n_pixels_pred = np.sum(overlap, axis=0, keepdims=True)
    n_pixels_true = np.sum(overlap, axis=1, keepdims=True)
    iou = overlap / (n_pixels_pred + n_pixels_true - overlap)
    iou[np.isnan(iou)] = 0.0
    return iou

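A tiny worked example of the union identity described above (synthetic one-label images):

import numpy as np

# 2x2 image: true label 1 covers 3 pixels, predicted label 1 covers 2,
# they intersect on 2 pixels -> IoU = 2 / (3 + 2 - 2) = 2/3.
t = np.array([[1, 1], [1, 0]])
p = np.array([[1, 1], [0, 0]])
iou = _intersection_over_union(t, p)
print(iou[1, 1])  # ~0.6667
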
def _true_positive(iou, th):
    """Calculate the number of true positives at threshold th.

    Args:
        iou (float, np.ndarray): Array of IOU pairs.
        th (float): Threshold on IOU for positive label.

    Returns:
        tp (float): Number of true positives at threshold.

    How it works:
        (1) Find the minimum number of masks.
        (2) Define the cost matrix; for a given threshold, each element is more
            negative the higher the IoU is (perfect IoU is 1, worst is 0). The
            second term gets more negative with higher IoU, but less negative
            with greater n_min (but that's a constant...).
        (3) Solve the linear sum assignment problem. The costs array defines the cost
            of matching a true label with a predicted label, so the problem is to
            find the set of pairings that minimizes this cost. The scipy.optimize
            function gives the ordered lists of corresponding true and predicted labels.
        (4) Extract the IoUs from these pairings and then threshold to get a boolean array
            whose sum is the number of true positives that is returned.
    """
    n_min = min(iou.shape[0], iou.shape[1])
    costs = -(iou >= th).astype(float) - iou / (2 * n_min)
    true_ind, pred_ind = linear_sum_assignment(costs)
    match_ok = iou[true_ind, pred_ind] >= th
    tp = match_ok.sum()
    return tp
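A small numeric check of the matching step (hand-built IoU matrix):

import numpy as np

# Two true masks vs two predicted masks; only the diagonal pairing clears 0.5:
iou = np.array([[0.8, 0.1],
                [0.2, 0.6]])
print(_true_positive(iou, th=0.5))  # 2
print(_true_positive(iou, th=0.7))  # 1 (only the 0.8 pair survives)
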
models/seg_post_model/cellpose/models.py
ADDED
@@ -0,0 +1,524 @@
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import os, time
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import numpy as np
|
| 8 |
+
from tqdm import trange
|
| 9 |
+
import torch
|
| 10 |
+
from scipy.ndimage import gaussian_filter
|
| 11 |
+
import gc
|
| 12 |
+
import cv2
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
models_logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
from . import transforms, dynamics, utils, plot
|
| 19 |
+
from .vit_sam import Transformer
|
| 20 |
+
from .core import assign_device, run_net, run_3D
|
| 21 |
+
|
| 22 |
+
_CPSAM_MODEL_URL = "https://huggingface.co/mouseland/cellpose-sam/resolve/main/cpsam"
|
| 23 |
+
_MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
|
| 24 |
+
# _MODEL_DIR_DEFAULT = Path.home().joinpath(".cellpose", "models")
|
| 25 |
+
_MODEL_DIR_DEFAULT = Path("/media/data1/huix/seg/cellpose_models")
|
| 26 |
+
MODEL_DIR = Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT
|
| 27 |
+
|
| 28 |
+
MODEL_NAMES = ["cpsam"]
|
| 29 |
+
|
| 30 |
+
MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt"))
|
| 31 |
+
|
| 32 |
+
normalize_default = {
|
| 33 |
+
"lowhigh": None,
|
| 34 |
+
"percentile": None,
|
| 35 |
+
"normalize": True,
|
| 36 |
+
"norm3D": True,
|
| 37 |
+
"sharpen_radius": 0,
|
| 38 |
+
"smooth_radius": 0,
|
| 39 |
+
"tile_norm_blocksize": 0,
|
| 40 |
+
"tile_norm_smooth3D": 1,
|
| 41 |
+
"invert": False
|
| 42 |
+
}
|
| 43 |
+
|
| 44 |
+
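These defaults can be overridden per call by passing a dict as the `normalize` argument of `eval` further below; a minimal sketch:

# Override only the keys you care about; the rest fall back to normalize_default:
custom_norm = {**normalize_default, "percentile": [1.0, 99.0], "norm3D": False}
# masks, flows, styles = model.eval(img, normalize=custom_norm)
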
# def model_path(model_type, model_index=0):
#     return cache_CPSAM_model_path()


# def cache_CPSAM_model_path():
#     MODEL_DIR.mkdir(parents=True, exist_ok=True)
#     cached_file = os.fspath(MODEL_DIR.joinpath('cpsam'))
#     if not os.path.exists(cached_file):
#         models_logger.info('Downloading: "{}" to {}\n'.format(_CPSAM_MODEL_URL, cached_file))
#         utils.download_url_to_file(_CPSAM_MODEL_URL, cached_file, progress=True)
#     return cached_file


def get_user_models():
    model_strings = []
    if os.path.exists(MODEL_LIST_PATH):
        with open(MODEL_LIST_PATH, "r") as textfile:
            lines = [line.rstrip() for line in textfile]
            if len(lines) > 0:
                model_strings.extend(lines)
    return model_strings

class CellposeModel():
    """
    Class representing a Cellpose model.

    Attributes:
        diam_mean (float): Mean "diameter" value for the model.
        builtin (bool): Whether the model is a built-in model or not.
        device (torch device): Device used for model running / training.
        nclasses (int): Number of classes in the model.
        nbase (list): List of base values for the model.
        net (CPnet): Cellpose network.
        pretrained_model (str): Path to pretrained cellpose model.
        pretrained_model_ortho (str): Path or model_name for pretrained cellpose model for ortho views in 3D.
        backbone (str): Type of network ("default" is the standard res-unet, "transformer" for the segformer).

    Methods:
        __init__(self, gpu=False, pretrained_model="", model_type=None, diam_mean=None, device=None, nchan=None, use_bfloat16=True, vit_checkpoint=None):
            Initialize the CellposeModel.

        eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None, z_axis=None, normalize=True, invert=False, rescale=None, diameter=None, flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, flow3D_smooth=0, stitch_threshold=0.0, min_size=15, max_size_fraction=0.4, niter=None, augment=False, tile_overlap=0.1, bsize=256, compute_masks=True, progress=None):
            Segment list of images x, or 4D array - Z x C x Y x X.
    """

    def __init__(self, gpu=False, pretrained_model="", model_type=None,
                 diam_mean=None, device=None, nchan=None, use_bfloat16=True, vit_checkpoint=None):
        """
        Initialize the CellposeModel.

        Parameters:
            gpu (bool, optional): Whether or not to save model to GPU, will check if GPU available.
            pretrained_model (str or list of strings, optional): Full path to pretrained cellpose model(s); if None or False, no model is loaded.
            model_type (str, optional): Any model that is available in the GUI, use name in GUI e.g. "livecell" (can be user-trained or model zoo).
            diam_mean (float, optional): Mean "diameter", 30. is built-in value for "cyto" model; 17. is built-in value for "nuclei" model; if saved in custom model file (cellpose>=2.0) then it will be loaded automatically and overwrite this value.
            device (torch device, optional): Device used for model running / training (torch.device("cuda") or torch.device("cpu")), overrides gpu input, recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
            use_bfloat16 (bool, optional): Use 16-bit float precision instead of 32-bit for model weights. Defaults to 16-bit (True).
        """
        # if diam_mean is not None:
        #     models_logger.warning(
        #         "diam_mean argument are not used in v4.0.1+. Ignoring this argument..."
        #     )
        # if model_type is not None:
        #     models_logger.warning(
        #         "model_type argument is not used in v4.0.1+. Ignoring this argument..."
        #     )
        # if nchan is not None:
        #     models_logger.warning("nchan argument is deprecated in v4.0.1+. Ignoring this argument")

        ### assign model device
        self.device = assign_device(gpu=gpu)[0] if device is None else device
        if torch.cuda.is_available():
            device_gpu = self.device.type == "cuda"
        elif torch.backends.mps.is_available():
            device_gpu = self.device.type == "mps"
        else:
            device_gpu = False
        self.gpu = device_gpu

        if pretrained_model is None:
            # raise ValueError("Must specify a pretrained model, training from scratch is not implemented")
            pretrained_model = ""

        ### create neural network
        if pretrained_model and not os.path.exists(pretrained_model):
            # check if pretrained model is in the models directory
            model_strings = get_user_models()
            all_models = MODEL_NAMES.copy()
            all_models.extend(model_strings)
            if pretrained_model in all_models:
                pretrained_model = os.path.join(MODEL_DIR, pretrained_model)
            else:
                pretrained_model = os.path.join(MODEL_DIR, "cpsam")
                models_logger.warning(
                    f"pretrained model {pretrained_model} not found, using default model"
                )

        self.pretrained_model = pretrained_model
        dtype = torch.bfloat16 if use_bfloat16 else torch.float32
        self.net = Transformer(dtype=dtype, checkpoint=vit_checkpoint).to(self.device)

        if os.path.exists(self.pretrained_model):
            models_logger.info(f">>>> loading model {self.pretrained_model}")
            self.net.load_model(self.pretrained_model, device=self.device)
        # else:
        #     try:
        #         if os.path.split(self.pretrained_model)[-1] != 'cpsam':
        #             raise FileNotFoundError('model file not recognized')
        #         cache_CPSAM_model_path()
        #         self.net.load_model(self.pretrained_model, device=self.device)
        #     except:
        #         print("ViT not initialized")

    def eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None,
             z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
             flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None,
             flow3D_smooth=0, stitch_threshold=0.0,
             min_size=15, max_size_fraction=0.4, niter=None,
             augment=False, tile_overlap=0.1, bsize=256,
             compute_masks=True, progress=None):

        # if rescale is not None:
        #     models_logger.warning("rescaling deprecated in v4.0.1+")
        # if channels is not None:
        #     models_logger.warning("channels deprecated in v4.0.1+. If data contain more than 3 channels, only the first 3 channels will be used")

        if isinstance(x, list) or x.squeeze().ndim == 5:
            self.timing = []
            masks, styles, flows = [], [], []
            tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
            nimg = len(x)
            iterator = trange(nimg, file=tqdm_out,
                              mininterval=30) if nimg > 1 else range(nimg)
            for i in iterator:
                tic = time.time()
                maski, flowi, stylei = self.eval(
                    x[i],
                    feat=None if feat is None else feat[i],
                    batch_size=batch_size,
                    channel_axis=channel_axis,
                    z_axis=z_axis,
                    normalize=normalize,
                    invert=invert,
                    diameter=diameter[i] if isinstance(diameter, list) or
                    isinstance(diameter, np.ndarray) else diameter,
                    do_3D=do_3D,
                    anisotropy=anisotropy,
                    augment=augment,
                    tile_overlap=tile_overlap,
                    bsize=bsize,
                    resample=resample,
                    flow_threshold=flow_threshold,
                    cellprob_threshold=cellprob_threshold,
                    compute_masks=compute_masks,
                    min_size=min_size,
                    max_size_fraction=max_size_fraction,
                    stitch_threshold=stitch_threshold,
                    flow3D_smooth=flow3D_smooth,
                    progress=progress,
                    niter=niter)
                masks.append(maski)
                flows.append(flowi)
                styles.append(stylei)
                self.timing.append(time.time() - tic)
            return masks, flows, styles

        ############# actual eval code ############
        # reshape image
        x = transforms.convert_image(x, channel_axis=channel_axis,
                                     z_axis=z_axis,
                                     do_3D=(do_3D or stitch_threshold > 0))

        # Add batch dimension if not present
        if x.ndim < 4:
            x = x[np.newaxis, ...]
        if feat is not None:
            if feat.ndim < 4:
                feat = feat[np.newaxis, ...]
        nimg = x.shape[0]

        image_scaling = None
        Ly_0 = x.shape[1]
        Lx_0 = x.shape[2]
        Lz_0 = None
        if do_3D or stitch_threshold > 0:
            Lz_0 = x.shape[0]
        if diameter is not None:
            image_scaling = 30. / diameter
            x = transforms.resize_image(x,
                                        Ly=int(x.shape[1] * image_scaling),
                                        Lx=int(x.shape[2] * image_scaling))
            if feat is not None:
                feat = transforms.resize_image(feat,
                                               Ly=int(feat.shape[1] * image_scaling),
                                               Lx=int(feat.shape[2] * image_scaling))

        # normalize image
        normalize_params = normalize_default
        if isinstance(normalize, dict):
            normalize_params = {**normalize_params, **normalize}
        elif not isinstance(normalize, bool):
            raise ValueError("normalize parameter must be a bool or a dict")
        else:
            normalize_params["normalize"] = normalize
            normalize_params["invert"] = invert

        # pre-normalize if 3D stack for stitching or do_3D
        do_normalization = True if normalize_params["normalize"] else False
        if nimg > 1 and do_normalization and (stitch_threshold or do_3D):
            normalize_params["norm3D"] = True if do_3D else normalize_params["norm3D"]
            x = transforms.normalize_img(x, **normalize_params)
            do_normalization = False  # do not normalize again
        else:
            if normalize_params["norm3D"] and nimg > 1 and do_normalization:
                models_logger.warning(
                    "normalize_params['norm3D'] is True but do_3D is False and stitch_threshold=0, so setting to False"
                )
                normalize_params["norm3D"] = False
            if do_normalization:
                x = transforms.normalize_img(x, **normalize_params)

        if feat is not None:
            if feat.shape[-1] > feat.shape[1]:
                # transpose feat to have channels last
                feat = np.moveaxis(feat, 1, -1)

        # adjust the anisotropy when diameter is specified and images are resized:
        if isinstance(anisotropy, (float, int)) and image_scaling:
            anisotropy = image_scaling * anisotropy

        dP, cellprob, styles = self._run_net(
            x,
            feat=feat,
            augment=augment,
            batch_size=batch_size,
            tile_overlap=tile_overlap,
            bsize=bsize,
            do_3D=do_3D,
            anisotropy=anisotropy)

        if do_3D:
            if flow3D_smooth > 0:
                models_logger.info(f"smoothing flows with sigma={flow3D_smooth}")
                dP = gaussian_filter(dP, (0, flow3D_smooth, flow3D_smooth, flow3D_smooth))
            torch.cuda.empty_cache()
            gc.collect()

            if resample:
                # upsample flows before computing masks from them:
                dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
                cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)

        if compute_masks:
            niter0 = 200
            niter = niter0 if niter is None or niter == 0 else niter
            masks = self._compute_masks(x.shape, dP, cellprob, flow_threshold=flow_threshold,
                                        cellprob_threshold=cellprob_threshold, min_size=min_size,
                                        max_size_fraction=max_size_fraction, niter=niter,
                                        stitch_threshold=stitch_threshold, do_3D=do_3D)
        else:
            masks = np.zeros(0)  # pass back zeros if not compute_masks

        masks, dP, cellprob = masks.squeeze(), dP.squeeze(), cellprob.squeeze()

        # undo resizing:
        if image_scaling is not None or anisotropy is not None:
            dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)  # works for 2D or 3D
            cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)

            if do_3D:
                if compute_masks:
                    # Rescale xy then xz:
                    masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
                    masks = masks.transpose(1, 0, 2)
                    masks = transforms.resize_image(masks, Ly=Lz_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
                    masks = masks.transpose(1, 0, 2)
            else:
                # 2D or 3D stitching case:
                if compute_masks:
                    masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)

        return masks, [plot.dx_to_circ(dP), dP, cellprob], styles

    def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
        """
        Resize cellprob array to specified dimensions for either 2D or 3D.

        Parameters:
            prob (numpy.ndarray): The cellprobs to resize, either in 2D or 3D. Returns the same ndim as provided.
            to_y_size (int): The target size along the Y-axis.
            to_x_size (int): The target size along the X-axis.
            to_z_size (int, optional): The target size along the Z-axis. Required
                for 3D cellprobs.

        Returns:
            numpy.ndarray: The resized cellprobs array with the same number of dimensions
            as the input.

        Raises:
            ValueError: If the input cellprobs array does not have 2 or 3 dimensions after squeezing.
        """
        prob_shape = prob.shape
        prob = prob.squeeze()
        squeeze_happened = prob.shape != prob_shape
        prob_shape = np.array(prob_shape)

        if prob.ndim == 2:
            # 2D case:
            prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
            if squeeze_happened:
                prob = np.expand_dims(prob, int(np.argwhere(prob_shape == 1)))  # add back empty axis for compatibility
        elif prob.ndim == 3:
            # 3D case:
            prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
            prob = prob.transpose(1, 0, 2)
            prob = transforms.resize_image(prob, Ly=to_z_size, Lx=to_x_size, no_channels=True)
            prob = prob.transpose(1, 0, 2)
        else:
            raise ValueError(f'cellprob has incorrect dimension after squeezing. Should be 2 or 3, prob shape: {prob.shape}')

        return prob

    def _resize_gradients(self, grads: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
        """
        Resize gradient arrays to specified dimensions for either 2D or 3D gradients.

        Parameters:
            grads (np.ndarray): The gradients to resize, either in 2D or 3D. Returns the same ndim as provided.
            to_y_size (int): The target size along the Y-axis.
            to_x_size (int): The target size along the X-axis.
            to_z_size (int, optional): The target size along the Z-axis. Required
                for 3D gradients.

        Returns:
            numpy.ndarray: The resized gradient array with the same number of dimensions
            as the input.

        Raises:
            ValueError: If the input gradient array does not have 3 or 4 dimensions.
        """
        grads_shape = grads.shape
        grads = grads.squeeze()
        squeeze_happened = grads.shape != grads_shape
        grads_shape = np.array(grads_shape)

        if grads.ndim == 3:
            # 2D case, with XY flows in 2 channels:
            grads = np.moveaxis(grads, 0, -1)  # put gradients last
            grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
            grads = np.moveaxis(grads, -1, 0)  # put gradients first

            if squeeze_happened:
                grads = np.expand_dims(grads, int(np.argwhere(grads_shape == 1)))  # add back empty axis for compatibility
        elif grads.ndim == 4:
            # dP has gradients that can be treated as channels:
            grads = grads.transpose(1, 2, 3, 0)  # move gradients last
            grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
            grads = grads.transpose(1, 0, 2, 3)  # switch axes to resize again
            grads = transforms.resize_image(grads, Ly=to_z_size, Lx=to_x_size, no_channels=False)
            grads = grads.transpose(3, 1, 0, 2)  # undo transposition
        else:
            raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 3 or 4, grads shape: {grads.shape}')

        return grads

    def _run_net(self, x, feat=None,
                 augment=False,
                 batch_size=8, tile_overlap=0.1,
                 bsize=224, anisotropy=1.0, do_3D=False):
        """ run network on image x """
        tic = time.time()
        shape = x.shape
        nimg = shape[0]

        if do_3D:
            Lz, Ly, Lx = shape[:-1]
            if anisotropy is not None and anisotropy != 1.0:
                models_logger.info(f"resizing 3D image with anisotropy={anisotropy}")
                x = transforms.resize_image(x.transpose(1, 0, 2, 3),
                                            Ly=int(Lz * anisotropy),
                                            Lx=int(Lx)).transpose(1, 0, 2, 3)
            yf, styles = run_3D(self.net, x,
                                batch_size=batch_size, augment=augment,
                                tile_overlap=tile_overlap,
                                bsize=bsize)
            cellprob = yf[..., -1]
            dP = yf[..., :-1].transpose((3, 0, 1, 2))
        else:
            yf, styles = run_net(self.net, x, feat=feat, bsize=bsize, augment=augment,
                                 batch_size=batch_size,
                                 tile_overlap=tile_overlap)
            cellprob = yf[..., -1]
            dP = yf[..., -3:-1].transpose((3, 0, 1, 2))
            if yf.shape[-1] > 3:
                styles = yf[..., :-3]

        styles = styles.squeeze()

        net_time = time.time() - tic
        if nimg > 1:
            models_logger.info("network run in %2.2fs" % (net_time))

        return dP, cellprob, styles

    def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_threshold=0.0,
                       min_size=15, max_size_fraction=0.4, niter=None,
                       do_3D=False, stitch_threshold=0.0):
        """ compute masks from flows and cell probability """
        changed_device_from = None
        if self.device.type == "mps" and do_3D:
            models_logger.warning("MPS does not support 3D post-processing, switching to CPU")
            self.device = torch.device("cpu")
            changed_device_from = "mps"
        Lz, Ly, Lx = shape[:3]
        tic = time.time()
        if do_3D:
            masks = dynamics.resize_and_compute_masks(
                dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
                flow_threshold=flow_threshold, do_3D=do_3D,
                min_size=min_size, max_size_fraction=max_size_fraction,
                resize=shape[:3] if (np.array(dP.shape[-3:]) != np.array(shape[:3])).sum()
                else None,
                device=self.device)
        else:
            nimg = shape[0]
            Ly0, Lx0 = cellprob[0].shape
            resize = None if Ly0 == Ly and Lx0 == Lx else [Ly, Lx]
            tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
            iterator = trange(nimg, file=tqdm_out,
                              mininterval=30) if nimg > 1 else range(nimg)
            for i in iterator:
                # turn off min_size for 3D stitching
                min_size0 = min_size if stitch_threshold == 0 or nimg == 1 else -1
                outputs = dynamics.resize_and_compute_masks(
                    dP[:, i], cellprob[i],
                    niter=niter, cellprob_threshold=cellprob_threshold,
                    flow_threshold=flow_threshold, resize=resize,
                    min_size=min_size0, max_size_fraction=max_size_fraction,
                    device=self.device)
                if i == 0 and nimg > 1:
                    masks = np.zeros((nimg, shape[1], shape[2]), outputs.dtype)
                if nimg > 1:
                    masks[i] = outputs
                else:
                    masks = outputs

            if stitch_threshold > 0 and nimg > 1:
                models_logger.info(
                    f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks"
                )
                masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
                masks = utils.fill_holes_and_remove_small_masks(
                    masks, min_size=min_size)
            elif nimg > 1:
                models_logger.warning(
                    "3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only"
                )

        flow_time = time.time() - tic
        if shape[0] > 1:
            models_logger.info("masks created in %2.2fs" % (flow_time))

        if changed_device_from is not None:
            models_logger.info("switching back to device %s" % self.device)
            self.device = torch.device(changed_device_from)
        return masks
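A minimal end-to-end sketch of the class above (import paths assume this vendored package is importable as laid out in the repo; the image path is hypothetical, and an unknown `pretrained_model` name falls back to the local `cpsam` weights):

from models.seg_post_model.cellpose.models import CellposeModel
from models.seg_post_model.cellpose.io import imread

model = CellposeModel(gpu=True, pretrained_model="cpsam")
img = imread("img.png")  # hypothetical input image
masks, flows, styles = model.eval(img, diameter=30.,
                                  flow_threshold=0.4, cellprob_threshold=0.0)
print(masks.max(), "objects found")
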
models/seg_post_model/cellpose/plot.py
ADDED
@@ -0,0 +1,281 @@
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
import numpy as np
|
| 6 |
+
import cv2
|
| 7 |
+
from scipy.ndimage import gaussian_filter
|
| 8 |
+
from . import utils, io, transforms
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import matplotlib
|
| 12 |
+
MATPLOTLIB_ENABLED = True
|
| 13 |
+
except:
|
| 14 |
+
MATPLOTLIB_ENABLED = False
|
| 15 |
+
|
| 16 |
+
try:
|
| 17 |
+
from skimage import color
|
| 18 |
+
from skimage.segmentation import find_boundaries
|
| 19 |
+
SKIMAGE_ENABLED = True
|
| 20 |
+
except:
|
| 21 |
+
SKIMAGE_ENABLED = False
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
# modified to use sinebow color
|
| 25 |
+
def dx_to_circ(dP):
|
| 26 |
+
"""Converts the optic flow representation to a circular color representation.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
dP (ndarray): Flow field components [dy, dx].
|
| 30 |
+
|
| 31 |
+
Returns:
|
| 32 |
+
ndarray: The circular color representation of the optic flow.
|
| 33 |
+
|
| 34 |
+
"""
|
| 35 |
+
mag = 255 * np.clip(transforms.normalize99(np.sqrt(np.sum(dP**2, axis=0))), 0, 1.)
|
| 36 |
+
angles = np.arctan2(dP[1], dP[0]) + np.pi
|
| 37 |
+
a = 2
|
| 38 |
+
mag /= a
|
| 39 |
+
rgb = np.zeros((*dP.shape[1:], 3), "uint8")
|
| 40 |
+
rgb[..., 0] = np.clip(mag * (np.cos(angles) + 1), 0, 255).astype("uint8")
|
| 41 |
+
rgb[..., 1] = np.clip(mag * (np.cos(angles + 2 * np.pi / 3) + 1), 0, 255).astype("uint8")
|
| 42 |
+
rgb[..., 2] = np.clip(mag * (np.cos(angles + 4 * np.pi / 3) + 1), 0, 255).astype("uint8")
|
| 43 |
+
|
| 44 |
+
return rgb
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
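Aside (illustrative only, not part of the uploaded file): a minimal sketch of calling dx_to_circ on a synthetic flow field; the import path and array values are assumptions.

import numpy as np
from models.seg_post_model.cellpose import plot  # assumed importable from the repo root

yy, xx = np.mgrid[-1:1:64j, -1:1:64j]
dP = np.stack([yy, xx])       # synthetic radial flow: dP[0]=dy, dP[1]=dx
rgb = plot.dx_to_circ(dP)     # hue encodes flow angle, brightness encodes magnitude
print(rgb.shape, rgb.dtype)   # (64, 64, 3) uint8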
def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None):
    """Plot segmentation results (like on website).

    Can save each panel of figure with file_name option. Use channels option if
    img input is not an RGB image with 3 channels.

    Args:
        fig (matplotlib.pyplot.figure): Figure in which to make plot.
        img (ndarray): 2D or 3D array. Image input into cellpose.
        maski (int, ndarray): For image k, masks[k] output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
        flowi (int, ndarray): For image k, flows[k][0] output from Cellpose.eval (RGB of flows).
        channels (list of int, optional): Channels used to run Cellpose, no need to use if image is RGB. Defaults to [0, 0].
        file_name (str, optional): File name of image. If file_name is not None, figure panels are saved. Defaults to None.
    """
    if not MATPLOTLIB_ENABLED:
        raise ImportError(
            "matplotlib not installed, install with 'pip install matplotlib'")
    ax = fig.add_subplot(1, 4, 1)
    img0 = img.copy()

    if img0.shape[0] < 4:
        img0 = np.transpose(img0, (1, 2, 0))
    if img0.shape[-1] < 3 or img0.ndim < 3:
        img0 = image_to_rgb(img0, channels=channels)
    else:
        if img0.max() <= 50.0:
            img0 = np.uint8(np.clip(img0, 0, 1) * 255)
    ax.imshow(img0)
    ax.set_title("original image")
    ax.axis("off")

    outlines = utils.masks_to_outlines(maski)

    overlay = mask_overlay(img0, maski)

    ax = fig.add_subplot(1, 4, 2)
    outX, outY = np.nonzero(outlines)
    imgout = img0.copy()
    imgout[outX, outY] = np.array([255, 0, 0])  # pure red

    ax.imshow(imgout)
    ax.set_title("predicted outlines")
    ax.axis("off")

    ax = fig.add_subplot(1, 4, 3)
    ax.imshow(overlay)
    ax.set_title("predicted masks")
    ax.axis("off")

    ax = fig.add_subplot(1, 4, 4)
    ax.imshow(flowi)
    ax.set_title("predicted cell pose")
    ax.axis("off")

    if file_name is not None:
        save_path = os.path.splitext(file_name)[0]
        io.imsave(save_path + "_overlay.jpg", overlay)
        io.imsave(save_path + "_outlines.jpg", imgout)
        io.imsave(save_path + "_flows.jpg", flowi)

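Aside (illustrative only): a hedged sketch of driving show_segmentation from Cellpose outputs; `model` and `img` are assumed to exist, and the eval return signature here is an assumption.

import matplotlib.pyplot as plt
from models.seg_post_model.cellpose import plot

masks, flows, styles = model.eval(img)   # assumed pretrained CellposeModel
fig = plt.figure(figsize=(12, 3))
plot.show_segmentation(fig, img, masks, flows[0], file_name="example")  # flows[0] is the RGB flow
plt.tight_layout()
plt.show()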
def mask_rgb(masks, colors=None):
    """Masks in random RGB colors.

    Args:
        masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
        colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range.

    Returns:
        RGB (uint8, 3D array): Array of masks overlaid on grayscale image.
    """
    if colors is not None:
        if colors.max() > 1:
            colors = np.float32(colors)
            colors /= 255
        colors = utils.rgb_to_hsv(colors)

    HSV = np.zeros((masks.shape[0], masks.shape[1], 3), np.float32)
    HSV[:, :, 2] = 1.0
    for n in range(int(masks.max())):
        ipix = (masks == n + 1).nonzero()
        if colors is None:
            HSV[ipix[0], ipix[1], 0] = np.random.rand()
        else:
            HSV[ipix[0], ipix[1], 0] = colors[n, 0]
        HSV[ipix[0], ipix[1], 1] = np.random.rand() * 0.5 + 0.5
        HSV[ipix[0], ipix[1], 2] = np.random.rand() * 0.5 + 0.5
    RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8)
    return RGB


def mask_overlay(img, masks, colors=None):
    """Overlay masks on image (set image to grayscale).

    Args:
        img (int or float, 2D or 3D array): Image of size [Ly x Lx (x nchan)].
        masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels.
        colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range.

    Returns:
        RGB (uint8, 3D array): Array of masks overlaid on grayscale image.
    """
    if colors is not None:
        if colors.max() > 1:
            colors = np.float32(colors)
            colors /= 255
        colors = utils.rgb_to_hsv(colors)
    if img.ndim > 2:
        img = img.astype(np.float32).mean(axis=-1)
    else:
        img = img.astype(np.float32)

    HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32)
    HSV[:, :, 2] = np.clip((img / 255. if img.max() > 1 else img) * 1.5, 0, 1)
    hues = np.linspace(0, 1, masks.max() + 1)[np.random.permutation(masks.max())]
    for n in range(int(masks.max())):
        ipix = (masks == n + 1).nonzero()
        if colors is None:
            HSV[ipix[0], ipix[1], 0] = hues[n]
        else:
            HSV[ipix[0], ipix[1], 0] = colors[n, 0]
        HSV[ipix[0], ipix[1], 1] = 1.0
    RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8)
    return RGB

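Aside (illustrative only): mask_overlay on toy data; shapes and values are assumptions.

import numpy as np
from models.seg_post_model.cellpose import plot

img = np.random.rand(64, 64)             # toy grayscale image
masks = np.zeros((64, 64), np.uint16)
masks[10:30, 10:30] = 1                  # one square "cell"
masks[35:55, 35:55] = 2                  # another
overlay = plot.mask_overlay(img, masks)  # uint8 (64, 64, 3), one hue per mask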
def image_to_rgb(img0, channels=[0, 0]):
    """Converts image from 2 x Ly x Lx or Ly x Lx x 2 to RGB Ly x Lx x 3.

    Args:
        img0 (ndarray): Input image of shape 2 x Ly x Lx or Ly x Lx x 2.

    Returns:
        ndarray: RGB image of shape Ly x Lx x 3.

    """
    img = img0.copy()
    img = img.astype(np.float32)
    if img.ndim < 3:
        img = img[:, :, np.newaxis]
    if img.shape[0] < 5:
        img = np.transpose(img, (1, 2, 0))
    if channels[0] == 0:
        img = img.mean(axis=-1)[:, :, np.newaxis]
    for i in range(img.shape[-1]):
        if np.ptp(img[:, :, i]) > 0:
            img[:, :, i] = np.clip(transforms.normalize99(img[:, :, i]), 0, 1)
            img[:, :, i] = np.clip(img[:, :, i], 0, 1)
    img *= 255
    img = np.uint8(img)
    RGB = np.zeros((img.shape[0], img.shape[1], 3), np.uint8)
    if img.shape[-1] == 1:
        RGB = np.tile(img, (1, 1, 3))
    else:
        RGB[:, :, channels[0] - 1] = img[:, :, 0]
        if channels[1] > 0:
            RGB[:, :, channels[1] - 1] = img[:, :, 1]
    return RGB


def interesting_patch(mask, bsize=130):
    """
    Get patch of size bsize x bsize with most masks.

    Args:
        mask (ndarray): Input mask.
        bsize (int): Size of the patch.

    Returns:
        tuple: Patch coordinates (y, x).

    """
    Ly, Lx = mask.shape
    m = np.float32(mask > 0)
    m = gaussian_filter(m, bsize / 2)
    y, x = np.unravel_index(np.argmax(m), m.shape)
    ycent = max(bsize // 2, min(y, Ly - bsize // 2))
    xcent = max(bsize // 2, min(x, Lx - bsize // 2))
    patch = [
        np.arange(ycent - bsize // 2, ycent + bsize // 2, 1, int),
        np.arange(xcent - bsize // 2, xcent + bsize // 2, 1, int)
    ]
    return patch


def disk(med, r, Ly, Lx):
    """Returns the pixels of a disk with a given radius and center.

    Args:
        med (tuple): The center coordinates of the disk.
        r (float): The radius of the disk.
        Ly (int): The height of the image.
        Lx (int): The width of the image.

    Returns:
        tuple: A tuple containing the y and x coordinates of the pixels within the disk.

    """
    yy, xx = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int),
                         indexing="ij")
    inds = ((yy - med[0])**2 + (xx - med[1])**2)**0.5 <= r
    y = yy[inds].flatten()
    x = xx[inds].flatten()
    return y, x


def outline_view(img0, maski, color=[1, 0, 0], mode="inner"):
    """
    Generates a red outline overlay onto the image.

    Args:
        img0 (numpy.ndarray): The input image.
        maski (numpy.ndarray): The mask representing the region of interest.
        color (list, optional): The color of the outline overlay. Defaults to [1, 0, 0] (red).
        mode (str, optional): The mode for generating the outline. Defaults to "inner".

    Returns:
        numpy.ndarray: The image with the red outline overlay.

    """
    if img0.ndim == 2:
        img0 = np.stack([img0] * 3, axis=-1)
    elif img0.ndim != 3:
        raise ValueError("img0 not right size (must have ndim 2 or 3)")

    if SKIMAGE_ENABLED:
        outlines = find_boundaries(maski, mode=mode)
    else:
        outlines = utils.masks_to_outlines(maski, mode=mode)
    outY, outX = np.nonzero(outlines)
    imgout = img0.copy()
    imgout[outY, outX] = np.array(color)

    return imgout
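Aside (illustrative only): the remaining plotting helpers compose the same way; this sketch reuses the toy img/masks from the earlier sketch and assumes skimage (or utils.masks_to_outlines) is available.

rgb = plot.image_to_rgb(img)                     # grayscale -> uint8 RGB
outlined = plot.outline_view(rgb / 255., masks)  # red "inner" outlines over the image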
models/seg_post_model/cellpose/transforms.py
ADDED
@@ -0,0 +1,1261 @@
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""
import logging

import cv2
import numpy as np
import torch
from scipy.ndimage import gaussian_filter1d
from torch.fft import fft2, fftshift, ifft2

transforms_logger = logging.getLogger(__name__)


def _taper_mask(ly=224, lx=224, sig=7.5):
    """
    Generate a taper mask.

    Args:
        ly (int): The height of the mask. Default is 224.
        lx (int): The width of the mask. Default is 224.
        sig (float): The sigma value for the tapering function. Default is 7.5.

    Returns:
        numpy.ndarray: The taper mask.

    """
    bsize = max(224, max(ly, lx))
    xm = np.arange(bsize)
    xm = np.abs(xm - xm.mean())
    mask = 1 / (1 + np.exp((xm - (bsize / 2 - 20)) / sig))
    mask = mask * mask[:, np.newaxis]
    mask = mask[bsize // 2 - ly // 2:bsize // 2 + ly // 2 + ly % 2,
                bsize // 2 - lx // 2:bsize // 2 + lx // 2 + lx % 2]
    return mask


def unaugment_tiles(y):
    """Reverse test-time augmentations for averaging (includes flipping of flowsY and flowsX).

    Args:
        y (float32): Array of shape (ntiles_y, ntiles_x, chan, Ly, Lx) where chan = (flowsY, flowsX, cell prob).

    Returns:
        float32: Array of shape (ntiles_y, ntiles_x, chan, Ly, Lx).

    """
    for j in range(y.shape[0]):
        for i in range(y.shape[1]):
            if j % 2 == 0 and i % 2 == 1:
                y[j, i] = y[j, i, :, ::-1, :]
                y[j, i, 0] *= -1
            elif j % 2 == 1 and i % 2 == 0:
                y[j, i] = y[j, i, :, :, ::-1]
                y[j, i, 1] *= -1
            elif j % 2 == 1 and i % 2 == 1:
                y[j, i] = y[j, i, :, ::-1, ::-1]
                y[j, i, 0] *= -1
                y[j, i, 1] *= -1
    return y


def average_tiles(y, ysub, xsub, Ly, Lx):
    """
    Average the results of the network over tiles.

    Args:
        y (float): Output of cellpose network for each tile. Shape: [ntiles x nclasses x bsize x bsize]
        ysub (list): List of arrays with start and end of tiles in Y of length ntiles
        xsub (list): List of arrays with start and end of tiles in X of length ntiles
        Ly (int): Size of pre-tiled image in Y (may be larger than original image if image size is less than bsize)
        Lx (int): Size of pre-tiled image in X (may be larger than original image if image size is less than bsize)

    Returns:
        yf (float32): Network output averaged over tiles. Shape: [nclasses x Ly x Lx]
    """
    Navg = np.zeros((Ly, Lx))
    yf = np.zeros((y.shape[1], Ly, Lx), np.float32)
    # taper edges of tiles
    mask = _taper_mask(ly=y.shape[-2], lx=y.shape[-1])
    for j in range(len(ysub)):
        yf[:, ysub[j][0]:ysub[j][1], xsub[j][0]:xsub[j][1]] += y[j] * mask
        Navg[ysub[j][0]:ysub[j][1], xsub[j][0]:xsub[j][1]] += mask
    yf /= Navg
    return yf

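Aside (illustrative only): a round-trip sketch of the tiling helpers, using make_tiles (defined just below) with an identity "network"; the shapes and the vendored import path are assumptions.

import numpy as np
from models.seg_post_model.cellpose import transforms

imgi = np.random.rand(1, 300, 400).astype(np.float32)    # (nchan, Ly, Lx)
IMG, ysub, xsub, Ly, Lx = transforms.make_tiles(imgi, bsize=224)
tiles = IMG.reshape(-1, *IMG.shape[2:])                   # (ntiles, nchan, 224, 224)
yf = transforms.average_tiles(tiles, ysub, xsub, Ly, Lx)  # taper-weighted average
assert np.allclose(yf, imgi, atol=1e-4)                   # crops average back to the input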
def make_tiles(imgi, bsize=224, augment=False, tile_overlap=0.1):
    """Make tiles of image to run at test-time.

    Args:
        imgi (np.ndarray): Array of shape (nchan, Ly, Lx) representing the input image.
        bsize (int, optional): Size of tiles. Defaults to 224.
        augment (bool, optional): Whether to flip tiles and set tile_overlap=2. Defaults to False.
        tile_overlap (float, optional): Fraction of overlap of tiles. Defaults to 0.1.

    Returns:
        A tuple containing (IMG, ysub, xsub, Ly, Lx):
            IMG (np.ndarray): Array of shape (ntiles, nchan, bsize, bsize) representing the tiles.
            ysub (list): List of arrays with start and end of tiles in Y of length ntiles.
            xsub (list): List of arrays with start and end of tiles in X of length ntiles.
            Ly (int): Height of the input image.
            Lx (int): Width of the input image.
    """
    nchan, Ly, Lx = imgi.shape
    if augment:
        bsize = np.int32(bsize)
        # pad if image smaller than bsize
        if Ly < bsize:
            imgi = np.concatenate((imgi, np.zeros((nchan, bsize - Ly, Lx))), axis=1)
            Ly = bsize
        if Lx < bsize:
            imgi = np.concatenate((imgi, np.zeros((nchan, Ly, bsize - Lx))), axis=2)
        Ly, Lx = imgi.shape[-2:]

        # tiles overlap by half of tile size
        ny = max(2, int(np.ceil(2. * Ly / bsize)))
        nx = max(2, int(np.ceil(2. * Lx / bsize)))
        ystart = np.linspace(0, Ly - bsize, ny).astype(int)
        xstart = np.linspace(0, Lx - bsize, nx).astype(int)

        ysub = []
        xsub = []

        # flip tiles so that overlapping segments are processed in rotation
        IMG = np.zeros((len(ystart), len(xstart), nchan, bsize, bsize), np.float32)
        for j in range(len(ystart)):
            for i in range(len(xstart)):
                ysub.append([ystart[j], ystart[j] + bsize])
                xsub.append([xstart[i], xstart[i] + bsize])
                IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]]
                # flip tiles to allow for augmentation of overlapping segments
                if j % 2 == 0 and i % 2 == 1:
                    IMG[j, i] = IMG[j, i, :, ::-1, :]
                elif j % 2 == 1 and i % 2 == 0:
                    IMG[j, i] = IMG[j, i, :, :, ::-1]
                elif j % 2 == 1 and i % 2 == 1:
                    IMG[j, i] = IMG[j, i, :, ::-1, ::-1]
    else:
        tile_overlap = min(0.5, max(0.05, tile_overlap))
        bsizeY, bsizeX = min(bsize, Ly), min(bsize, Lx)
        bsizeY = np.int32(bsizeY)
        bsizeX = np.int32(bsizeX)
        # tiles overlap by 10% tile size
        ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
        nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))
        ystart = np.linspace(0, Ly - bsizeY, ny).astype(int)
        xstart = np.linspace(0, Lx - bsizeX, nx).astype(int)

        ysub = []
        xsub = []
        IMG = np.zeros((len(ystart), len(xstart), nchan, bsizeY, bsizeX), np.float32)
        for j in range(len(ystart)):
            for i in range(len(xstart)):
                ysub.append([ystart[j], ystart[j] + bsizeY])
                xsub.append([xstart[i], xstart[i] + bsizeX])
                IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]]

    return IMG, ysub, xsub, Ly, Lx


def normalize99(Y, lower=1, upper=99, copy=True, downsample=False):
    """
    Normalize the image so that 0.0 corresponds to the 1st percentile and 1.0 corresponds to the 99th percentile.

    Args:
        Y (ndarray): The input image (for downsample, use [Ly x Lx] or [Lz x Ly x Lx]).
        lower (int, optional): The lower percentile. Defaults to 1.
        upper (int, optional): The upper percentile. Defaults to 99.
        copy (bool, optional): Whether to create a copy of the input image. Defaults to True.
        downsample (bool, optional): Whether to downsample image to compute percentiles. Defaults to False.

    Returns:
        ndarray: The normalized image.
    """
    X = Y.copy() if copy else Y
    X = X.astype("float32") if X.dtype != "float64" and X.dtype != "float32" else X
    if downsample and X.size > 224**3:
        nskip = [max(1, X.shape[i] // 224) for i in range(X.ndim)]
        nskip[0] = max(1, X.shape[0] // 50) if X.ndim == 3 else nskip[0]
        slc = tuple([slice(0, X.shape[i], nskip[i]) for i in range(X.ndim)])
        x01 = np.percentile(X[slc], lower)
        x99 = np.percentile(X[slc], upper)
    else:
        x01 = np.percentile(X, lower)
        x99 = np.percentile(X, upper)
    if x99 - x01 > 1e-3:
        X -= x01
        X /= (x99 - x01)
    else:
        X[:] = 0
    return X

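Aside (illustrative only): normalize99 in one line on synthetic data.

import numpy as np
from models.seg_post_model.cellpose import transforms

x = np.random.exponential(1., (256, 256)).astype(np.float32)
xn = transforms.normalize99(x)                      # 1st percentile -> 0.0, 99th -> 1.0
print(np.percentile(xn, 1), np.percentile(xn, 99))  # ~0.0 ~1.0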
def normalize99_tile(img, blocksize=100, lower=1., upper=99., tile_overlap=0.1,
                     norm3D=False, smooth3D=1, is3D=False):
    """Compute normalization like normalize99 function but in tiles.

    Args:
        img (numpy.ndarray): Array of shape (Lz x) Ly x Lx (x nchan) containing the image.
        blocksize (float, optional): Size of tiles. Defaults to 100.
        lower (float, optional): Lower percentile for normalization. Defaults to 1.0.
        upper (float, optional): Upper percentile for normalization. Defaults to 99.0.
        tile_overlap (float, optional): Fraction of overlap of tiles. Defaults to 0.1.
        norm3D (bool, optional): Use same tiled normalization for each z-plane. Defaults to False.
        smooth3D (int, optional): Smoothing factor for 3D normalization. Defaults to 1.
        is3D (bool, optional): Set to True if image is a 3D stack. Defaults to False.

    Returns:
        numpy.ndarray: Normalized image array of shape (Lz x) Ly x Lx (x nchan).
    """
    is1c = True if img.ndim == 2 or (is3D and img.ndim == 3) else False
    is3D = True if img.ndim > 3 or (is3D and img.ndim == 3) else False
    img = img[..., np.newaxis] if is1c else img
    img = img[np.newaxis, ...] if img.ndim == 3 else img
    Lz, Ly, Lx, nchan = img.shape

    tile_overlap = min(0.5, max(0.05, tile_overlap))
    blocksizeY, blocksizeX = min(blocksize, Ly), min(blocksize, Lx)
    blocksizeY = np.int32(blocksizeY)
    blocksizeX = np.int32(blocksizeX)
    # tiles overlap by 10% tile size
    ny = 1 if Ly <= blocksize else int(np.ceil(
        (1. + 2 * tile_overlap) * Ly / blocksize))
    nx = 1 if Lx <= blocksize else int(np.ceil(
        (1. + 2 * tile_overlap) * Lx / blocksize))
    ystart = np.linspace(0, Ly - blocksizeY, ny).astype(int)
    xstart = np.linspace(0, Lx - blocksizeX, nx).astype(int)
    ysub = []
    xsub = []
    for j in range(len(ystart)):
        for i in range(len(xstart)):
            ysub.append([ystart[j], ystart[j] + blocksizeY])
            xsub.append([xstart[i], xstart[i] + blocksizeX])

    x01_tiles_z = []
    x99_tiles_z = []
    for z in range(Lz):
        IMG = np.zeros((len(ystart), len(xstart), blocksizeY, blocksizeX, nchan),
                       "float32")
        k = 0
        for j in range(len(ystart)):
            for i in range(len(xstart)):
                IMG[j, i] = img[z, ysub[k][0]:ysub[k][1], xsub[k][0]:xsub[k][1], :]
                k += 1
        x01_tiles = np.percentile(IMG, lower, axis=(-3, -2))
        x99_tiles = np.percentile(IMG, upper, axis=(-3, -2))

        # fill areas with small differences with neighboring squares
        to_fill = np.zeros(x01_tiles.shape[:2], "bool")
        for c in range(nchan):
            to_fill = x99_tiles[:, :, c] - x01_tiles[:, :, c] < 1e-3
            if to_fill.sum() > 0 and to_fill.sum() < x99_tiles[:, :, c].size:
                fill_vals = np.nonzero(to_fill)
                fill_neigh = np.nonzero(~to_fill)
                nearest_neigh = (
                    (fill_vals[0] - fill_neigh[0][:, np.newaxis])**2 +
                    (fill_vals[1] - fill_neigh[1][:, np.newaxis])**2).argmin(axis=0)
                x01_tiles[fill_vals[0], fill_vals[1],
                          c] = x01_tiles[fill_neigh[0][nearest_neigh],
                                         fill_neigh[1][nearest_neigh], c]
                x99_tiles[fill_vals[0], fill_vals[1],
                          c] = x99_tiles[fill_neigh[0][nearest_neigh],
                                         fill_neigh[1][nearest_neigh], c]
            elif to_fill.sum() > 0 and to_fill.sum() == x99_tiles[:, :, c].size:
                x01_tiles[:, :, c] = 0
                x99_tiles[:, :, c] = 1
        x01_tiles_z.append(x01_tiles)
        x99_tiles_z.append(x99_tiles)

    x01_tiles_z = np.array(x01_tiles_z)
    x99_tiles_z = np.array(x99_tiles_z)
    # do not smooth over z-axis if not normalizing separately per plane
    for a in range(2):
        x01_tiles_z = gaussian_filter1d(x01_tiles_z, 1, axis=a)
        x99_tiles_z = gaussian_filter1d(x99_tiles_z, 1, axis=a)
        if norm3D:
            smooth3D = 1 if smooth3D == 0 else smooth3D
            x01_tiles_z = gaussian_filter1d(x01_tiles_z, smooth3D, axis=a)
            x99_tiles_z = gaussian_filter1d(x99_tiles_z, smooth3D, axis=a)

    if not norm3D and Lz > 1:
        x01 = np.zeros((len(x01_tiles_z), Ly, Lx, nchan), "float32")
        x99 = np.zeros((len(x01_tiles_z), Ly, Lx, nchan), "float32")
        for z in range(Lz):
            x01_rsz = cv2.resize(x01_tiles_z[z], (Lx, Ly),
                                 interpolation=cv2.INTER_LINEAR)
            x01[z] = x01_rsz[..., np.newaxis] if nchan == 1 else x01_rsz
            x99_rsz = cv2.resize(x99_tiles_z[z], (Lx, Ly),
                                 interpolation=cv2.INTER_LINEAR)
            x99[z] = x99_rsz[..., np.newaxis] if nchan == 1 else x99_rsz
        if (x99 - x01).min() < 1e-3:
            raise ZeroDivisionError(
                "cannot use norm3D=False with tile_norm, sample is too sparse; set norm3D=True or tile_norm=0"
            )
    else:
        x01 = cv2.resize(x01_tiles_z.mean(axis=0), (Lx, Ly),
                         interpolation=cv2.INTER_LINEAR)
        x99 = cv2.resize(x99_tiles_z.mean(axis=0), (Lx, Ly),
                         interpolation=cv2.INTER_LINEAR)
        if x01.ndim < 3:
            x01 = x01[..., np.newaxis]
            x99 = x99[..., np.newaxis]

    if is1c:
        img, x01, x99 = img.squeeze(), x01.squeeze(), x99.squeeze()
    elif not is3D:
        img, x01, x99 = img[0], x01[0], x99[0]

    # normalize
    img -= x01
    img /= (x99 - x01)

    return img

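Aside (illustrative only): normalize99_tile above corrects uneven illumination by normalizing per tile; the gradient here is synthetic.

import numpy as np
from models.seg_post_model.cellpose import transforms

yy, xx = np.mgrid[0:256, 0:256]
shading = 0.5 + xx / 512.                               # synthetic illumination gradient
img = (np.random.rand(256, 256) * shading).astype(np.float32)
img_n = transforms.normalize99_tile(img, blocksize=64)  # per-tile percentile normalization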
def gaussian_kernel(sigma, Ly, Lx, device=torch.device("cpu")):
    """
    Generates a 2D Gaussian kernel.

    Args:
        sigma (float): Standard deviation of the Gaussian distribution.
        Ly (int): Number of pixels in the y-axis.
        Lx (int): Number of pixels in the x-axis.
        device (torch.device, optional): Device to store the kernel tensor. Defaults to torch.device("cpu").

    Returns:
        torch.Tensor: 2D Gaussian kernel tensor.

    """
    y = torch.linspace(-Ly / 2, Ly / 2 + 1, Ly, device=device)
    x = torch.linspace(-Ly / 2, Ly / 2 + 1, Lx, device=device)
    y, x = torch.meshgrid(y, x, indexing="ij")
    kernel = torch.exp(-(y**2 + x**2) / (2 * sigma**2))
    kernel /= kernel.sum()
    return kernel


def smooth_sharpen_img(img, smooth_radius=6, sharpen_radius=12,
                       device=torch.device("cpu"), is3D=False):
    """Sharpen blurry images with surround subtraction and/or smooth noisy images.

    Args:
        img (float32): Array that's (Lz x) Ly x Lx (x nchan).
        smooth_radius (float, optional): Size of gaussian smoothing filter, recommended to be 1/10-1/4 of cell diameter
            (if also sharpening, should be 2-3x smaller than sharpen_radius). Defaults to 6.
        sharpen_radius (float, optional): Size of gaussian surround filter, recommended to be 1/8-1/2 of cell diameter
            (if also smoothing, should be 2-3x larger than smooth_radius). Defaults to 12.
        device (torch.device, optional): Device on which to perform sharpening.
            Will be faster on GPU but need to ensure GPU has RAM for image. Defaults to torch.device("cpu").
        is3D (bool, optional): If image is 3D stack (only necessary to set if img.ndim==3). Defaults to False.

    Returns:
        img_sharpen (float32): Array that's (Lz x) Ly x Lx (x nchan).
    """
    img_sharpen = torch.from_numpy(img.astype("float32")).to(device)
    shape = img_sharpen.shape

    is1c = True if img_sharpen.ndim == 2 or (is3D and img_sharpen.ndim == 3) else False
    is3D = True if img_sharpen.ndim > 3 or (is3D and img_sharpen.ndim == 3) else False
    img_sharpen = img_sharpen.unsqueeze(-1) if is1c else img_sharpen
    img_sharpen = img_sharpen.unsqueeze(0) if img_sharpen.ndim == 3 else img_sharpen
    Lz, Ly, Lx, nchan = img_sharpen.shape

    if smooth_radius > 0:
        kernel = gaussian_kernel(smooth_radius, Ly, Lx, device=device)
        if sharpen_radius > 0:
            kernel += -1 * gaussian_kernel(sharpen_radius, Ly, Lx, device=device)
    elif sharpen_radius > 0:
        kernel = -1 * gaussian_kernel(sharpen_radius, Ly, Lx, device=device)
        kernel[Ly // 2, Lx // 2] = 1

    fhp = fft2(kernel)
    for z in range(Lz):
        for c in range(nchan):
            img_filt = torch.real(ifft2(
                fft2(img_sharpen[z, :, :, c]) * torch.conj(fhp)))
            img_filt = fftshift(img_filt)
            img_sharpen[z, :, :, c] = img_filt

    img_sharpen = img_sharpen.reshape(shape)
    return img_sharpen.cpu().numpy()

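Aside (illustrative only): smooth_sharpen_img applies a difference-of-Gaussians filter in Fourier space; the radii here are arbitrary assumptions.

import numpy as np
from models.seg_post_model.cellpose import transforms

img = np.random.rand(256, 256).astype(np.float32)
img_dog = transforms.smooth_sharpen_img(img, smooth_radius=2, sharpen_radius=6)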
def move_axis(img, m_axis=-1, first=True):
    """ move axis m_axis to first or last position """
    if m_axis == -1:
        m_axis = img.ndim - 1
    m_axis = min(img.ndim - 1, m_axis)
    axes = np.arange(0, img.ndim)
    if first:
        axes[1:m_axis + 1] = axes[:m_axis]
        axes[0] = m_axis
    else:
        axes[m_axis:-1] = axes[m_axis + 1:]
        axes[-1] = m_axis
    img = img.transpose(tuple(axes))
    return img


def move_min_dim(img, force=False):
    """Move the minimum dimension last as channels if it is less than 10 or force is True.

    Args:
        img (ndarray): The input image.
        force (bool, optional): If True, the minimum dimension will always be moved.
            Defaults to False.

    Returns:
        ndarray: The image with the minimum dimension moved to the last axis as channels.
    """
    if len(img.shape) > 2:
        min_dim = min(img.shape)
        if min_dim < 10 or force:
            if img.shape[-1] == min_dim:
                channel_axis = -1
            else:
                channel_axis = (img.shape).index(min_dim)
            img = move_axis(img, m_axis=channel_axis, first=False)
    return img


def update_axis(m_axis, to_squeeze, ndim):
    """
    Squeeze the axis value based on the given parameters.

    Args:
        m_axis (int): The current axis value.
        to_squeeze (numpy.ndarray): An array of indices to squeeze.
        ndim (int): The number of dimensions.

    Returns:
        m_axis (int or None): The updated axis value.
    """
    if m_axis == -1:
        m_axis = ndim - 1
    if (to_squeeze == m_axis).sum() == 1:
        m_axis = None
    else:
        inds = np.ones(ndim, bool)
        inds[to_squeeze] = False
        m_axis = np.nonzero(np.arange(0, ndim)[inds] == m_axis)[0]
        if len(m_axis) > 0:
            m_axis = m_axis[0]
        else:
            m_axis = None
    return m_axis

def _convert_image_3d(x, channel_axis=None, z_axis=None):
    """
    Convert a 3D or 4D image array to have dimensions ordered as (Z, X, Y, C).

    Arrays of ndim=3 are assumed to be grayscale and must be specified with z_axis.
    Arrays of ndim=4 must have both `channel_axis` and `z_axis` specified.

    Args:
        x (numpy.ndarray): Input image array. Must be either 3D (assumed to be grayscale 3D) or 4D.
        channel_axis (int): The axis index corresponding to the channel dimension in the input array. \
            Must be specified for 4D images.
        z_axis (int): The axis index corresponding to the depth (Z) dimension in the input array. \
            Must be specified for both 3D and 4D images.

    Returns:
        numpy.ndarray: A 4D image array with dimensions ordered as (Z, X, Y, C), where C is the channel
        dimension. If the input has fewer than 3 channels, the output will be padded with zeros to \
        have 3 channels. If the input has more than 3 channels, only the first 3 channels will be retained.

    Raises:
        ValueError: If `z_axis` is not specified for 3D images. If either `channel_axis` or `z_axis` \
            is not specified for 4D images. If the input image does not have 3 or 4 dimensions.

    Notes:
        - For 3D images (ndim=3), the function assumes the input is grayscale and adds a singleton channel dimension.
        - The function reorders the dimensions of the input array to ensure the output has the desired (Z, X, Y, C) order.
        - If the number of channels is not equal to 3, the function either truncates or pads the \
            channels to ensure the output has exactly 3 channels.
    """

    if x.ndim < 3:
        raise ValueError(f"Input image must have at least 3 dimensions, input shape: {x.shape}, ndim={x.ndim}")

    if z_axis is not None and z_axis < 0:
        z_axis += x.ndim

    # if image is ndim==3, assume it is greyscale 3D and use provided z_axis
    if x.ndim == 3 and z_axis is not None:
        # add in channel axis
        x = x[..., np.newaxis]
        channel_axis = 3
    elif x.ndim == 3 and z_axis is None:
        raise ValueError("z_axis must be specified when segmenting 3D images of ndim=3")

    if channel_axis is None or z_axis is None:
        raise ValueError("For 4D images, both `channel_axis` and `z_axis` must be explicitly specified. Please provide values for both parameters.")
    if channel_axis is not None and channel_axis < 0:
        channel_axis += x.ndim
    if channel_axis is None or channel_axis >= x.ndim:
        raise IndexError(f"channel_axis {channel_axis} is out of bounds for input array with {x.ndim} dimensions")
    assert x.ndim == 4, f"input image must have ndim == 4, ndim={x.ndim}"

    x_dim_shapes = list(x.shape)
    num_z_layers = x_dim_shapes[z_axis]
    num_channels = x_dim_shapes[channel_axis]
    x_xy_axes = [i for i in range(x.ndim)]

    # need to remove the z and channels from the shapes:
    # delete the one with the bigger index first
    if z_axis > channel_axis:
        del x_dim_shapes[z_axis]
        del x_dim_shapes[channel_axis]

        del x_xy_axes[z_axis]
        del x_xy_axes[channel_axis]
    else:
        del x_dim_shapes[channel_axis]
        del x_dim_shapes[z_axis]

        del x_xy_axes[channel_axis]
        del x_xy_axes[z_axis]

    x = x.transpose((z_axis, x_xy_axes[0], x_xy_axes[1], channel_axis))

    # Handle cases with not 3 channels:
    if num_channels != 3:
        x_chans_to_copy = min(3, num_channels)

        if num_channels > 3:
            transforms_logger.warning("more than 3 channels provided, only segmenting on first 3 channels")
            x = x[..., :x_chans_to_copy]
        else:
            # less than 3 channels: pad up to
            pad_width = [(0, 0), (0, 0), (0, 0), (0, 3 - x_chans_to_copy)]
            x = np.pad(x, pad_width, mode='constant', constant_values=0)

    return x

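Aside (illustrative only): _convert_image_3d is an internal helper, but its axis bookkeeping is easy to see on a toy stack.

import numpy as np
from models.seg_post_model.cellpose import transforms

stack = np.random.rand(5, 2, 128, 128)                              # (Z, C, Y, X)
out = transforms._convert_image_3d(stack, channel_axis=1, z_axis=0)
print(out.shape)                                                    # (5, 128, 128, 3), 3rd channel zero-padded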
def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
    """Converts the image to have the z-axis first, channels last. Image will be converted to 3 channels if it is not already.
    If more than 3 channels are provided, only the first 3 channels will be used.

    Accepts:
        - 2D images with no channel dimension: `z_axis` and `channel_axis` must be `None`
        - 2D images with channel dimension: `channel_axis` will be guessed between first or last axis, can also specify `channel_axis`. `z_axis` must be `None`
        - 3D images with or without channels

    Args:
        x (numpy.ndarray or torch.Tensor): The input image.
        channel_axis (int or None): The axis of the channels in the input image. If None, the axis is determined automatically.
        z_axis (int or None): The axis of the z-dimension in the input image. If None, the axis is determined automatically.
        do_3D (bool): Whether to process the image in 3D mode. Defaults to False.

    Returns:
        numpy.ndarray: The converted image.

    Raises:
        ValueError: If the input image is 2D and do_3D is True.
        ValueError: If the input image is 4D and do_3D is False.
    """

    # check if image is a torch array instead of numpy array, convert to numpy
    ndim = x.ndim
    if torch.is_tensor(x):
        transforms_logger.warning("torch array used as input, converting to numpy")
        x = x.cpu().numpy()

    # z_axis is only meaningful when do_3D=True
    if z_axis is not None and not do_3D:
        raise ValueError("2D image provided, but z_axis is not None. Set z_axis=None to process 2D images of ndim=2 or 3.")

    # make sure that channel_axis and z_axis are specified if 3D
    if ndim == 4 and not do_3D:
        raise ValueError("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images. ndims=4")

    # 3D images are handled separately
    if do_3D:
        return _convert_image_3d(x, channel_axis=channel_axis, z_axis=z_axis)

    ######################## 2D reshaping ########################
    # if user specifies channel axis, return early
    if channel_axis is not None:
        if ndim == 2:
            raise ValueError("2D image provided, but channel_axis is not None. Set channel_axis=None to process 2D images of ndim=2.")

        # Put channel axis last:
        # Find the indices of the dims that need to be put in dim 0 and 1
        n_channels = x.shape[channel_axis]
        x_shape_dims = list(x.shape)
        del x_shape_dims[channel_axis]
        dimension_indices = [i for i in range(x.ndim)]
        del dimension_indices[channel_axis]

        x = x.transpose((dimension_indices[0], dimension_indices[1], channel_axis))

        if n_channels != 3:
            x_chans_to_copy = min(3, n_channels)

            if n_channels > 3:
                transforms_logger.warning("more than 3 channels provided, only segmenting on first 3 channels")
                x = x[..., :x_chans_to_copy]
            else:
                x_out = np.zeros((x_shape_dims[0], x_shape_dims[1], 3), dtype=x.dtype)
                x_out[..., :x_chans_to_copy] = x[...]
                x = x_out
                del x_out

        return x

    # do image padding and channel conversion
    if ndim == 2:
        # grayscale image, make 3 channels
        x_out = np.zeros((x.shape[0], x.shape[1], 3), dtype=x.dtype)
        x_out[..., 0] = x
        x = x_out
        del x_out
    elif ndim == 3:
        # assume 2d with channels
        # find dim with smaller size between first and last dims
        move_channel_axis = x.shape[0] < x.shape[2]
        if move_channel_axis:
            x = x.transpose((1, 2, 0))

        # zero padding up to 3 channels:
        num_channels = x.shape[-1]
        if num_channels > 3:
            transforms_logger.warning("Found more than 3 channels, only using first 3")
            num_channels = 3
        x_out = np.zeros((x.shape[0], x.shape[1], 3), dtype=x.dtype)
        x_out[..., :num_channels] = x[..., :num_channels]
        x = x_out
        del x_out
    else:
        # something is wrong: yell
        expected_shapes = "2D (H, W), 3D (H, W, C), or 4D (Z, H, W, C)"
        transforms_logger.critical(f"ERROR: Unexpected image shape: {str(x.shape)}. Expected shapes: {expected_shapes}")
        raise ValueError(f"ERROR: Unexpected image shape: {str(x.shape)}. Expected shapes: {expected_shapes}")

    return x

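Aside (illustrative only): convert_image for the common 2D cases.

import numpy as np
from models.seg_post_model.cellpose import transforms

img2d = np.random.rand(128, 128)
x = transforms.convert_image(img2d)  # (128, 128, 3); channel 0 holds the image
img2c = np.random.rand(2, 128, 128)  # channels-first, 2 channels
x = transforms.convert_image(img2c)  # channel axis guessed, zero-padded to 3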
def normalize_img(img, normalize=True, norm3D=True, invert=False, lowhigh=None,
                  percentile=(1., 99.), sharpen_radius=0, smooth_radius=0,
                  tile_norm_blocksize=0, tile_norm_smooth3D=1, axis=-1):
    """Normalize each channel of the image with optional inversion, smoothing, and sharpening.

    Args:
        img (ndarray): The input image. It should have at least 3 dimensions.
            If it is 4-dimensional, it assumes the first non-channel axis is the Z dimension.
        normalize (bool, optional): Whether to perform normalization. Defaults to True.
        norm3D (bool, optional): Whether to normalize in 3D. If True, the entire 3D stack will
            be normalized per channel. If False, normalization is applied per Z-slice. Defaults to True.
        invert (bool, optional): Whether to invert the image. Useful if cells are dark instead of bright.
            Defaults to False.
        lowhigh (tuple or ndarray, optional): The lower and upper bounds for normalization.
            Can be a tuple of two values (applied to all channels) or an array of shape (nchan, 2)
            for per-channel normalization. Incompatible with smoothing and sharpening.
            Defaults to None.
        percentile (tuple, optional): The lower and upper percentiles for normalization. If provided, it should be
            a tuple of two values. Each value should be between 0 and 100. Defaults to (1.0, 99.0).
        sharpen_radius (int, optional): The radius for sharpening the image. Defaults to 0.
        smooth_radius (int, optional): The radius for smoothing the image. Defaults to 0.
        tile_norm_blocksize (int, optional): The block size for tile-based normalization. Defaults to 0.
        tile_norm_smooth3D (int, optional): The smoothness factor for tile-based normalization in 3D. Defaults to 1.
        axis (int, optional): The channel axis to loop over for normalization. Defaults to -1.

    Returns:
        ndarray: The normalized image of the same size.

    Raises:
        ValueError: If the image has less than 3 dimensions.
        ValueError: If the provided lowhigh or percentile values are invalid.
        ValueError: If the image is inverted without normalization.

    """
    if img.ndim < 3:
        error_message = "Image needs to have at least 3 dimensions"
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    img_norm = img if img.dtype == "float32" else img.astype(np.float32)
    if axis != -1 and axis != img_norm.ndim - 1:
        img_norm = np.moveaxis(img_norm, axis, -1)  # Move channel axis to last

    nchan = img_norm.shape[-1]

    # Validate and handle lowhigh bounds
    if lowhigh is not None:
        lowhigh = np.array(lowhigh)
        if lowhigh.shape == (2,):
            lowhigh = np.tile(lowhigh, (nchan, 1))  # Expand to per-channel bounds
        elif lowhigh.shape != (nchan, 2):
            error_message = "`lowhigh` must have shape (2,) or (nchan, 2)"
            transforms_logger.critical(error_message)
            raise ValueError(error_message)

    # Validate percentile
    if percentile is None:
        percentile = (1.0, 99.0)
    elif not (0 <= percentile[0] < percentile[1] <= 100):
        error_message = "Invalid percentile range, should be between 0 and 100"
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    # Apply normalization based on lowhigh or percentile
    cgood = np.zeros(nchan, "bool")
    if lowhigh is not None:
        for c in range(nchan):
            lower = lowhigh[c, 0]
            upper = lowhigh[c, 1]
            img_norm[..., c] -= lower
            img_norm[..., c] /= (upper - lower)
            cgood[c] = True
    else:
        # Apply sharpening and smoothing if specified
        if sharpen_radius > 0 or smooth_radius > 0:
            img_norm = smooth_sharpen_img(
                img_norm, sharpen_radius=sharpen_radius, smooth_radius=smooth_radius
            )

        # Apply tile-based normalization or standard normalization
        if tile_norm_blocksize > 0:
            img_norm = normalize99_tile(
                img_norm,
                blocksize=tile_norm_blocksize,
                lower=percentile[0],
                upper=percentile[1],
                smooth3D=tile_norm_smooth3D,
                norm3D=norm3D,
            )
            cgood[:] = True
        elif normalize:
            if img_norm.ndim == 3 or norm3D:  # i.e. if YXC, or ZYXC with norm3D=True
                for c in range(nchan):
                    if np.ptp(img_norm[..., c]) > 0.:
                        img_norm[..., c] = normalize99(
                            img_norm[..., c],
                            lower=percentile[0],
                            upper=percentile[1],
                            copy=False, downsample=True,
                        )
                        cgood[c] = True
            else:  # i.e. if ZYXC with norm3D=False then per Z-slice
                for z in range(img_norm.shape[0]):
                    for c in range(nchan):
                        if np.ptp(img_norm[z, ..., c]) > 0.:
                            img_norm[z, ..., c] = normalize99(
                                img_norm[z, ..., c],
                                lower=percentile[0],
                                upper=percentile[1],
                                copy=False, downsample=True,
                            )
                            cgood[c] = True

    if invert:
        if lowhigh is not None or tile_norm_blocksize > 0 or normalize:
            for c in range(nchan):
                if cgood[c]:
                    img_norm[..., c] = 1 - img_norm[..., c]
        else:
            error_message = "Cannot invert image without normalization"
            transforms_logger.critical(error_message)
            raise ValueError(error_message)

    # Move channel axis back to the original position
    if axis != -1 and axis != img_norm.ndim - 1:
        img_norm = np.moveaxis(img_norm, -1, axis)

    # The transformer can get confused if a channel is all 1's instead of all 0's:
    for i, chan_did_normalize in enumerate(cgood):
        if not chan_did_normalize:
            if img_norm.ndim == 3:
                img_norm[:, :, i] = 0
            if img_norm.ndim == 4:
                img_norm[:, :, :, i] = 0

    return img_norm

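Aside (illustrative only): normalize_img ties the pieces together; note that float32 inputs are normalized in place, hence the .copy() calls below.

import numpy as np
from models.seg_post_model.cellpose import transforms

img = np.random.rand(128, 128, 3).astype(np.float32)
img_n = transforms.normalize_img(img.copy())                 # per-channel (1, 99) percentile
img_inv = transforms.normalize_img(img.copy(), invert=True)  # for dark cells on bright background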
def resize_safe(img, Ly, Lx, interpolation=cv2.INTER_LINEAR):
|
| 783 |
+
"""OpenCV resize function does not support uint32.
|
| 784 |
+
|
| 785 |
+
This function converts the image to float32 before resizing and then converts it back to uint32. Not safe!
|
| 786 |
+
References issue: https://github.com/MouseLand/cellpose/issues/937
|
| 787 |
+
|
| 788 |
+
Implications:
|
| 789 |
+
* Runtime: Runtime increases by 5x-50x due to type casting. However, with resizing being very efficient, this is not
|
| 790 |
+
a big issue. A 10,000x10,000 image takes 0.47s instead of 0.016s to cast and resize on 32 cores on GPU.
|
| 791 |
+
* Memory: However, memory usage increases. Not tested by how much.
|
| 792 |
+
|
| 793 |
+
Args:
|
| 794 |
+
img (ndarray): Image of size [Ly x Lx].
|
| 795 |
+
Ly (int): Desired height of the resized image.
|
| 796 |
+
Lx (int): Desired width of the resized image.
|
| 797 |
+
interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR.
|
| 798 |
+
|
| 799 |
+
Returns:
|
| 800 |
+
ndarray: Resized image of size [Ly x Lx].
|
| 801 |
+
|
| 802 |
+
"""
|
| 803 |
+
|
| 804 |
+
# cast image
|
| 805 |
+
cast = img.dtype == np.uint32
|
| 806 |
+
if cast:
|
| 807 |
+
img = img.astype(np.float32)
|
| 808 |
+
|
| 809 |
+
# resize
|
| 810 |
+
img = cv2.resize(img, (Lx, Ly), interpolation=interpolation)
|
| 811 |
+
|
| 812 |
+
# cast back
|
| 813 |
+
if cast:
|
| 814 |
+
img = img.round().astype(np.uint32)
|
| 815 |
+
|
| 816 |
+
return img
|
| 817 |
+
|
| 818 |
+
|
def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEAR,
                 no_channels=False):
    """Resize image for computing flows / unresize for computing dynamics.

    Args:
        img0 (ndarray): Image of size [Y x X x nchan] or [Lz x Y x X x nchan] or [Lz x Y x X].
        Ly (int, optional): Desired height of the resized image. Defaults to None.
        Lx (int, optional): Desired width of the resized image. Defaults to None.
        rsz (float, optional): Resize coefficient(s) for the image. If Ly is None, rsz is used. Defaults to None.
        interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR.
        no_channels (bool, optional): Flag indicating whether to treat the third dimension as a channel.
            Defaults to False.

    Returns:
        ndarray: Resized image of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan].

    Raises:
        ValueError: If both Ly and rsz are None.
    """
    if Ly is None and rsz is None:
        error_message = "must give size to resize to or factor to use for resizing"
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    if Ly is None:
        # determine Ly and Lx using rsz
        if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
            rsz = [rsz, rsz]
        if no_channels:
            Ly = int(img0.shape[-2] * rsz[-2])
            Lx = int(img0.shape[-1] * rsz[-1])
        else:
            Ly = int(img0.shape[-3] * rsz[-2])
            Lx = int(img0.shape[-2] * rsz[-1])

    # no_channels useful for z-stacks, so the third dimension is not treated as a channel
    # but if this is called for grayscale images, they first become [Ly,Lx,2] so ndim=3 but
    if (img0.ndim > 2 and no_channels) or (img0.ndim == 4 and not no_channels):
        if Ly == 0 or Lx == 0:
            raise ValueError(
                "anisotropy too high / low -- not enough pixels to resize to ratio")
        for i, img in enumerate(img0):
            imgi = resize_safe(img, Ly, Lx, interpolation=interpolation)
            if i == 0:
                if no_channels:
                    imgs = np.zeros((img0.shape[0], Ly, Lx), imgi.dtype)
                else:
                    imgs = np.zeros((img0.shape[0], Ly, Lx, img0.shape[-1]), imgi.dtype)
            imgs[i] = imgi if imgi.ndim > 2 or no_channels else imgi[..., np.newaxis]
    else:
        imgs = resize_safe(img0, Ly, Lx, interpolation=interpolation)
    return imgs

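A quick sketch of resize_image with a scalar factor (illustrative; the random image is a hypothetical stand-in): a scalar rsz is broadcast to both axes, and passing Ly/Lx explicitly undoes the resize.

    img = np.random.rand(100, 120, 2).astype(np.float32)    # [Y x X x nchan]
    half = resize_image(img, rsz=0.5)                       # -> shape (50, 60, 2)
    back = resize_image(half, Ly=100, Lx=120)               # back to the original grid
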
def get_pad_yx(Ly, Lx, div=16, extra=1, min_size=None):
    if min_size is None or Ly >= min_size[-2]:
        Lpad = int(div * np.ceil(Ly / div) - Ly)
    else:
        Lpad = min_size[-2] - Ly
    ypad1 = extra * div // 2 + Lpad // 2
    ypad2 = extra * div // 2 + Lpad - Lpad // 2
    if min_size is None or Lx >= min_size[-1]:
        Lpad = int(div * np.ceil(Lx / div) - Lx)
    else:
        Lpad = min_size[-1] - Lx
    xpad1 = extra * div // 2 + Lpad // 2
    xpad2 = extra * div // 2 + Lpad - Lpad // 2

    return ypad1, ypad2, xpad1, xpad2

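Worked example of the padding arithmetic (illustrative): for Ly = Lx = 100 with div=16 and extra=1, Lpad = 16*ceil(100/16) - 100 = 12, and each side additionally gets extra*div//2 = 8 pixels.

    ypad1, ypad2, xpad1, xpad2 = get_pad_yx(100, 100, div=16, extra=1)
    # each pad = 8 + 6 = 14, so the padded size is 100 + 28 = 128, a multiple of 16
    assert (ypad1, ypad2, xpad1, xpad2) == (14, 14, 14, 14)
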
+
def pad_image_ND(img0, div=16, extra=1, min_size=None, zpad=False):
|
| 891 |
+
"""Pad image for test-time so that its dimensions are a multiple of 16 (2D or 3D).
|
| 892 |
+
|
| 893 |
+
Args:
|
| 894 |
+
img0 (ndarray): Image of size [nchan (x Lz) x Ly x Lx].
|
| 895 |
+
div (int, optional): Divisor for padding. Defaults to 16.
|
| 896 |
+
extra (int, optional): Extra padding. Defaults to 1.
|
| 897 |
+
min_size (tuple, optional): Minimum size of the image. Defaults to None.
|
| 898 |
+
|
| 899 |
+
Returns:
|
| 900 |
+
A tuple containing (I, ysub, xsub) or (I, ysub, xsub, zsub), I is padded image, -sub are ranges of pixels in the padded image corresponding to img0.
|
| 901 |
+
|
| 902 |
+
"""
|
| 903 |
+
Ly, Lx = img0.shape[-2:]
|
| 904 |
+
ypad1, ypad2, xpad1, xpad2 = get_pad_yx(Ly, Lx, div=div, extra=extra, min_size=min_size)
|
| 905 |
+
|
| 906 |
+
if img0.ndim > 3:
|
| 907 |
+
if zpad:
|
| 908 |
+
Lpad = int(div * np.ceil(img0.shape[-3] / div) - img0.shape[-3])
|
| 909 |
+
zpad1 = extra * div // 2 + Lpad // 2
|
| 910 |
+
zpad2 = extra * div // 2 + Lpad - Lpad // 2
|
| 911 |
+
else:
|
| 912 |
+
zpad1, zpad2 = 0, 0
|
| 913 |
+
pads = np.array([[0, 0], [zpad1, zpad2], [ypad1, ypad2], [xpad1, xpad2]])
|
| 914 |
+
else:
|
| 915 |
+
pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])
|
| 916 |
+
|
| 917 |
+
I = np.pad(img0, pads, mode="constant")
|
| 918 |
+
|
| 919 |
+
ysub = np.arange(ypad1, ypad1 + Ly)
|
| 920 |
+
xsub = np.arange(xpad1, xpad1 + Lx)
|
| 921 |
+
if zpad:
|
| 922 |
+
zsub = np.arange(zpad1, zpad1 + img0.shape[-3])
|
| 923 |
+
return I, ysub, xsub, zsub
|
| 924 |
+
else:
|
| 925 |
+
return I, ysub, xsub
|
| 926 |
+
|
| 927 |
+
|
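Usage sketch for pad_image_ND (illustrative; the input array is hypothetical): pad for the network, then use ysub/xsub to slice a prediction back to the original grid.

    x = np.random.rand(2, 100, 100).astype(np.float32)      # [nchan x Ly x Lx]
    I, ysub, xsub = pad_image_ND(x, div=16, extra=1)
    assert I.shape[-2] % 16 == 0 and I.shape[-1] % 16 == 0
    recovered = I[:, ysub[0]:ysub[-1] + 1, xsub[0]:xsub[-1] + 1]
    assert np.array_equal(recovered, x)
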
def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=False,
                             zcrop=48, do_flip=True, rotate=True, rescale=None, unet=False,
                             random_per_image=True):
    """Augmentation by random rotation and resizing.

    Args:
        X (list of ND-arrays, float): List of image arrays of size [nchan x Ly x Lx] or [Ly x Lx].
        Y (list of ND-arrays, float, optional): List of image labels of size [nlabels x Ly x Lx] or [Ly x Lx].
            The 1st channel of Y is always nearest-neighbor interpolated (assumed to be masks or 0-1 representation).
            If Y.shape[0]==3 and not unet, then the labels are assumed to be [cell probability, Y flow, X flow].
            If unet, the second channel is dist_to_bound. Defaults to None.
        scale_range (float, optional): Range of resizing of images for augmentation.
            Images are resized by (1-scale_range/2) + scale_range * np.random.rand(). Defaults to 1.0.
        xy (tuple, int, optional): Size of transformed images to return. Defaults to (224,224).
        do_flip (bool, optional): Whether or not to flip images horizontally. Defaults to True.
        rotate (bool, optional): Whether or not to rotate images. Defaults to True.
        rescale (array, float, optional): How much to resize images by before performing augmentations. Defaults to None.
        unet (bool, optional): Whether or not to use unet. Defaults to False.
        random_per_image (bool, optional): Different random rotate and resize per image. Defaults to True.

    Returns:
        A tuple containing (imgi, lbl, scale): imgi (ND-array, float): Transformed images in array [nimg x nchan x xy[0] x xy[1]];
        lbl (ND-array, float): Transformed labels in array [nimg x nchan x xy[0] x xy[1]];
        scale (array, float): Amount each image was resized by.
    """
    scale_range = max(0, min(2, float(scale_range))) if scale_range is not None else scale_range
    nimg = len(X)
    if X[0].ndim > 2:
        nchan = X[0].shape[0]
    else:
        nchan = 1
    if do_3D and X[0].ndim > 3:
        shape = (zcrop, xy[0], xy[1])
    else:
        shape = (xy[0], xy[1])
    imgi = np.zeros((nimg, nchan, *shape), "float32")

    lbl = []
    if Y is not None:
        if Y[0].ndim > 2:
            nt = Y[0].shape[0]
        else:
            nt = 1
        lbl = np.zeros((nimg, nt, *shape), np.float32)

    scale = np.ones(nimg, np.float32)

    for n in range(nimg):

        if random_per_image or n == 0:
            Ly, Lx = X[n].shape[-2:]
            # generate random augmentation parameters
            flip = np.random.rand() > .5
            theta = np.random.rand() * np.pi * 2 if rotate else 0.
            if scale_range is None:
                scale[n] = 2 ** (4 * np.random.rand() - 2)
            else:
                scale[n] = (1 - scale_range / 2) + scale_range * np.random.rand()
            if rescale is not None:
                scale[n] *= 1. / rescale[n]
            dxy = np.maximum(0, np.array([Lx * scale[n] - xy[1],
                                          Ly * scale[n] - xy[0]]))
            dxy = (np.random.rand(2,) - .5) * dxy

            # create affine transform
            cc = np.array([Lx / 2, Ly / 2])
            cc1 = cc - np.array([Lx - xy[1], Ly - xy[0]]) / 2 + dxy
            pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
            pts2 = np.float32([
                cc1,
                cc1 + scale[n] * np.array([np.cos(theta), np.sin(theta)]),
                cc1 + scale[n] *
                np.array([np.cos(np.pi / 2 + theta),
                          np.sin(np.pi / 2 + theta)])
            ])
            M = cv2.getAffineTransform(pts1, pts2)

        img = X[n].copy()
        if Y is not None:
            labels = Y[n].copy()
            if labels.ndim < 3:
                labels = labels[np.newaxis, :, :]

        if do_3D:
            Lz = X[n].shape[-3]
            flip_z = np.random.rand() > .5
            lz = int(np.round(zcrop / scale[n]))
            iz = np.random.randint(0, Lz - lz)
            img = img[:, iz:iz + lz, :, :]
            if Y is not None:
                labels = labels[:, iz:iz + lz, :, :]

        if do_flip:
            if flip:
                img = img[..., ::-1]
                if Y is not None:
                    labels = labels[..., ::-1]
                    if nt > 1 and not unet:
                        labels[-1] = -labels[-1]
            if do_3D and flip_z:
                img = img[:, ::-1]
                if Y is not None:
                    labels = labels[:, ::-1]
                    if nt > 1 and not unet:
                        labels[-3] = -labels[-3]

        for k in range(nchan):
            if do_3D:
                img0 = np.zeros((lz, xy[0], xy[1]), "float32")
                for z in range(lz):
                    I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
                                       flags=cv2.INTER_LINEAR)
                    img0[z] = I
                if scale[n] != 1.0:
                    for y in range(imgi.shape[-2]):
                        imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
                                                      interpolation=cv2.INTER_LINEAR)
                else:
                    imgi[n, k] = img0
            else:
                I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
                imgi[n, k] = I

        if Y is not None:
            for k in range(nt):
                flag = cv2.INTER_NEAREST if k < nt - 2 else cv2.INTER_LINEAR
                if do_3D:
                    lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
                    for z in range(lz):
                        I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
                                           flags=flag)
                        lbl0[z] = I
                    if scale[n] != 1.0:
                        for y in range(lbl.shape[-2]):
                            lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
                                                         interpolation=flag)
                    else:
                        lbl[n, k] = lbl0
                else:
                    lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)

            if nt > 1 and not unet:
                v1 = lbl[n, -1].copy()
                v2 = lbl[n, -2].copy()
                lbl[n, -2] = (-v1 * np.sin(-theta) + v2 * np.cos(-theta))
                lbl[n, -1] = (v1 * np.cos(-theta) + v2 * np.sin(-theta))

    return imgi, lbl, scale

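Illustrative call (not part of the diff; the random arrays are stand-ins for real images and flow labels): augment four 2-channel images with 3-channel [cellprob, Y flow, X flow] labels.

    X = [np.random.rand(2, 300, 300).astype(np.float32) for _ in range(4)]
    Y = [np.random.rand(3, 300, 300).astype(np.float32) for _ in range(4)]
    imgi, lbl, scale = random_rotate_and_resize(X, Y, scale_range=0.5, xy=(224, 224))
    # imgi: (4, 2, 224, 224), lbl: (4, 3, 224, 224), scale: per-image resize factors
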
def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=(224, 224), do_3D=False,
                                       zcrop=48, do_flip=True, rotate=True, rescale=None, unet=False,
                                       random_per_image=True):
    """Augmentation by random rotation and resizing, applied jointly to images, labels, and feature maps.

    Args:
        X (list of ND-arrays, float): List of image arrays of size [nchan x Ly x Lx] or [Ly x Lx].
        Y (list of ND-arrays, float, optional): List of image labels of size [nlabels x Ly x Lx] or [Ly x Lx].
            The 1st channel of Y is always nearest-neighbor interpolated (assumed to be masks or 0-1 representation).
            If Y.shape[0]==3 and not unet, then the labels are assumed to be [cell probability, Y flow, X flow].
            If unet, the second channel is dist_to_bound. Defaults to None.
        feat (list of ND-arrays, float, optional): List of feature maps of size [nf x Ly x Lx] or [Ly x Lx],
            transformed with the same affine as X. Defaults to None.
        scale_range (float, optional): Range of resizing of images for augmentation.
            Images are resized by (1-scale_range/2) + scale_range * np.random.rand(). Defaults to 1.0.
        xy (tuple, int, optional): Size of transformed images to return. Defaults to (224,224).
        do_flip (bool, optional): Whether or not to flip images horizontally. Defaults to True.
        rotate (bool, optional): Whether or not to rotate images. Defaults to True.
        rescale (array, float, optional): How much to resize images by before performing augmentations. Defaults to None.
        unet (bool, optional): Whether or not to use unet. Defaults to False.
        random_per_image (bool, optional): Different random rotate and resize per image. Defaults to True.

    Returns:
        A tuple containing (imgi, lbl, feat_out, scale): imgi (ND-array, float): Transformed images in array [nimg x nchan x xy[0] x xy[1]];
        lbl (ND-array, float): Transformed labels in array [nimg x nchan x xy[0] x xy[1]];
        feat_out (ND-array, float): Transformed feature maps in array [nimg x nf x xy[0] x xy[1]];
        scale (array, float): Amount each image was resized by.
    """
    scale_range = max(0, min(2, float(scale_range))) if scale_range is not None else scale_range
    nimg = len(X)
    if X[0].ndim > 2:
        nchan = X[0].shape[0]
    else:
        nchan = 1
    if do_3D and X[0].ndim > 3:
        shape = (zcrop, xy[0], xy[1])
    else:
        shape = (xy[0], xy[1])
    imgi = np.zeros((nimg, nchan, *shape), "float32")

    lbl = []
    if Y is not None:
        if Y[0].ndim > 2:
            nt = Y[0].shape[0]
        else:
            nt = 1
        lbl = np.zeros((nimg, nt, *shape), np.float32)

    if feat is not None:
        if feat[0].ndim > 2:
            nf = feat[0].shape[0]
        else:
            nf = 1
        feat_out = np.zeros((nimg, nf, *shape), "float32")

    scale = np.ones(nimg, np.float32)

    for n in range(nimg):

        if random_per_image or n == 0:
            Ly, Lx = X[n].shape[-2:]
            # generate random augmentation parameters
            flip = np.random.rand() > .5
            theta = np.random.rand() * np.pi * 2 if rotate else 0.
            if scale_range is None:
                scale[n] = 2 ** (4 * np.random.rand() - 2)
            else:
                scale[n] = (1 - scale_range / 2) + scale_range * np.random.rand()
            if rescale is not None:
                scale[n] *= 1. / rescale[n]
            dxy = np.maximum(0, np.array([Lx * scale[n] - xy[1],
                                          Ly * scale[n] - xy[0]]))
            dxy = (np.random.rand(2,) - .5) * dxy

            # create affine transform
            cc = np.array([Lx / 2, Ly / 2])
            cc1 = cc - np.array([Lx - xy[1], Ly - xy[0]]) / 2 + dxy
            pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
            pts2 = np.float32([
                cc1,
                cc1 + scale[n] * np.array([np.cos(theta), np.sin(theta)]),
                cc1 + scale[n] *
                np.array([np.cos(np.pi / 2 + theta),
                          np.sin(np.pi / 2 + theta)])
            ])
            M = cv2.getAffineTransform(pts1, pts2)

        img = X[n].copy()
        if Y is not None:
            labels = Y[n].copy()
            if labels.ndim < 3:
                labels = labels[np.newaxis, :, :]
        if feat is not None:
            feats = feat[n].copy()
            if feats.ndim < 3:
                feats = feats[np.newaxis, :, :]

        if do_3D:
            Lz = X[n].shape[-3]
            flip_z = np.random.rand() > .5
            lz = int(np.round(zcrop / scale[n]))
            iz = np.random.randint(0, Lz - lz)
            img = img[:, iz:iz + lz, :, :]
            if Y is not None:
                labels = labels[:, iz:iz + lz, :, :]
            if feat is not None:
                feats = feats[:, iz:iz + lz, :, :]

        if do_flip:
            if flip:
                img = img[..., ::-1]
                if Y is not None:
                    labels = labels[..., ::-1]
                    if nt > 1 and not unet:
                        labels[-1] = -labels[-1]
                if feat is not None:
                    feats = feats[..., ::-1]
            if do_3D and flip_z:
                img = img[:, ::-1]
                if Y is not None:
                    labels = labels[:, ::-1]
                    if nt > 1 and not unet:
                        labels[-3] = -labels[-3]
                if feat is not None:
                    feats = feats[:, ::-1]

        for k in range(nchan):
            if do_3D:
                img0 = np.zeros((lz, xy[0], xy[1]), "float32")
                for z in range(lz):
                    I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
                                       flags=cv2.INTER_LINEAR)
                    img0[z] = I
                if scale[n] != 1.0:
                    for y in range(imgi.shape[-2]):
                        imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
                                                      interpolation=cv2.INTER_LINEAR)
                else:
                    imgi[n, k] = img0
            else:
                I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
                imgi[n, k] = I

        if Y is not None:
            for k in range(nt):
                flag = cv2.INTER_NEAREST if k < nt - 2 else cv2.INTER_LINEAR
                if do_3D:
                    lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
                    for z in range(lz):
                        I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
                                           flags=flag)
                        lbl0[z] = I
                    if scale[n] != 1.0:
                        for y in range(lbl.shape[-2]):
                            lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
                                                         interpolation=flag)
                    else:
                        lbl[n, k] = lbl0
                else:
                    lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)

            if nt > 1 and not unet:
                v1 = lbl[n, -1].copy()
                v2 = lbl[n, -2].copy()
                lbl[n, -2] = (-v1 * np.sin(-theta) + v2 * np.cos(-theta))
                lbl[n, -1] = (v1 * np.cos(-theta) + v2 * np.sin(-theta))

        if feat is not None:
            for k in range(nf):
                if do_3D:
                    feat0 = np.zeros((lz, xy[0], xy[1]), "float32")
                    for z in range(lz):
                        I = cv2.warpAffine(feats[k, z], M, (xy[1], xy[0]),
                                           flags=cv2.INTER_LINEAR)
                        feat0[z] = I
                    if scale[n] != 1.0:
                        for y in range(feat_out.shape[-2]):
                            feat_out[n, k, :, y] = cv2.resize(feat0[:, y], (xy[1], zcrop),
                                                              interpolation=cv2.INTER_LINEAR)
                    else:
                        feat_out[n, k] = feat0
                else:
                    feat_out[n, k] = cv2.warpAffine(feats[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)

    return imgi, lbl, feat_out, scale

models/seg_post_model/cellpose/utils.py
ADDED
@@ -0,0 +1,667 @@
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""
import logging
import os, tempfile, shutil, io
from tqdm import tqdm, trange
from urllib.request import urlopen
import cv2
from scipy.ndimage import find_objects, gaussian_filter, generate_binary_structure, label
from scipy.spatial import ConvexHull
import numpy as np
import colorsys
import fastremap
import fill_voids
from multiprocessing import Pool, cpu_count
# try:
#     from cellpose import metrics
# except:
#     import metrics as metrics
from models.seg_post_model.cellpose import metrics

try:
    from skimage.morphology import remove_small_holes
    SKIMAGE_ENABLED = True
except:
    SKIMAGE_ENABLED = False


class TqdmToLogger(io.StringIO):
    """
    Output stream for TQDM which will output to logger module instead of
    the StdOut.
    """
    logger = None
    level = None
    buf = ""

    def __init__(self, logger, level=None):
        super(TqdmToLogger, self).__init__()
        self.logger = logger
        self.level = level or logging.INFO

    def write(self, buf):
        self.buf = buf.strip("\r\n\t ")

    def flush(self):
        self.logger.log(self.level, self.buf)


def rgb_to_hsv(arr):
    rgb_to_hsv_channels = np.vectorize(colorsys.rgb_to_hsv)
    r, g, b = np.rollaxis(arr, axis=-1)
    h, s, v = rgb_to_hsv_channels(r, g, b)
    hsv = np.stack((h, s, v), axis=-1)
    return hsv


def hsv_to_rgb(arr):
    hsv_to_rgb_channels = np.vectorize(colorsys.hsv_to_rgb)
    h, s, v = np.rollaxis(arr, axis=-1)
    r, g, b = hsv_to_rgb_channels(h, s, v)
    rgb = np.stack((r, g, b), axis=-1)
    return rgb

def download_url_to_file(url, dst, progress=True):
|
| 67 |
+
r"""Download object at the given URL to a local path.
|
| 68 |
+
Thanks to torch, slightly modified
|
| 69 |
+
Args:
|
| 70 |
+
url (string): URL of the object to download
|
| 71 |
+
dst (string): Full path where object will be saved, e.g. `/tmp/temporary_file`
|
| 72 |
+
progress (bool, optional): whether or not to display a progress bar to stderr
|
| 73 |
+
Default: True
|
| 74 |
+
"""
|
| 75 |
+
file_size = None
|
| 76 |
+
import ssl
|
| 77 |
+
ssl._create_default_https_context = ssl._create_unverified_context
|
| 78 |
+
u = urlopen(url)
|
| 79 |
+
meta = u.info()
|
| 80 |
+
if hasattr(meta, "getheaders"):
|
| 81 |
+
content_length = meta.getheaders("Content-Length")
|
| 82 |
+
else:
|
| 83 |
+
content_length = meta.get_all("Content-Length")
|
| 84 |
+
if content_length is not None and len(content_length) > 0:
|
| 85 |
+
file_size = int(content_length[0])
|
| 86 |
+
# We deliberately save it in a temp file and move it after
|
| 87 |
+
dst = os.path.expanduser(dst)
|
| 88 |
+
dst_dir = os.path.dirname(dst)
|
| 89 |
+
f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir)
|
| 90 |
+
try:
|
| 91 |
+
with tqdm(total=file_size, disable=not progress, unit="B", unit_scale=True,
|
| 92 |
+
unit_divisor=1024) as pbar:
|
| 93 |
+
while True:
|
| 94 |
+
buffer = u.read(8192)
|
| 95 |
+
if len(buffer) == 0:
|
| 96 |
+
break
|
| 97 |
+
f.write(buffer)
|
| 98 |
+
pbar.update(len(buffer))
|
| 99 |
+
f.close()
|
| 100 |
+
shutil.move(f.name, dst)
|
| 101 |
+
finally:
|
| 102 |
+
f.close()
|
| 103 |
+
if os.path.exists(f.name):
|
| 104 |
+
os.remove(f.name)
|
| 105 |
+
|
| 106 |
+
|
def distance_to_boundary(masks):
    """Get the distance to the boundary of mask pixels.

    Args:
        masks (int, 2D or 3D array): The masks array. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.

    Returns:
        dist_to_bound (2D or 3D array): The distance to the boundary. Size [Ly x Lx] or [Lz x Ly x Lx].

    Raises:
        ValueError: If the masks array is not 2D or 3D.
    """
    if masks.ndim > 3 or masks.ndim < 2:
        raise ValueError("distance_to_boundary takes 2D or 3D array, not %dD array" %
                         masks.ndim)
    dist_to_bound = np.zeros(masks.shape, np.float64)

    if masks.ndim == 3:
        for i in range(masks.shape[0]):
            dist_to_bound[i] = distance_to_boundary(masks[i])
        return dist_to_bound
    else:
        slices = find_objects(masks)
        for i, si in enumerate(slices):
            if si is not None:
                sr, sc = si
                mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
                contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                            cv2.CHAIN_APPROX_NONE)
                pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
                ypix, xpix = np.nonzero(mask)
                min_dist = ((ypix[:, np.newaxis] - pvr)**2 +
                            (xpix[:, np.newaxis] - pvc)**2).min(axis=1)
                dist_to_bound[ypix + sr.start, xpix + sc.start] = min_dist
        return dist_to_bound


def masks_to_edges(masks, threshold=1.0):
    """Get edges of masks as a 0-1 array.

    Args:
        masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.
        threshold (float, optional): Threshold value for distance to boundary. Defaults to 1.0.

    Returns:
        edges (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are edge pixels.
    """
    dist_to_bound = distance_to_boundary(masks)
    edges = (dist_to_bound < threshold) * (masks > 0)
    return edges


def remove_edge_masks(masks, change_index=True):
    """Removes masks with pixels on the edge of the image.

    Args:
        masks (int, 2D or 3D array): The masks to be processed. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
        change_index (bool, optional): If True, after removing masks, changes the indexing so that there are no missing label numbers. Defaults to True.

    Returns:
        masks (2D or 3D array): The processed masks. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels.
    """
    slices = find_objects(masks.astype(int))
    for i, si in enumerate(slices):
        remove = False
        if si is not None:
            for d, sid in enumerate(si):
                if sid.start == 0 or sid.stop == masks.shape[d]:
                    remove = True
                    break
            if remove:
                masks[si][masks[si] == i + 1] = 0
    shape = masks.shape
    if change_index:
        _, masks = np.unique(masks, return_inverse=True)
        masks = np.reshape(masks, shape).astype(np.int32)

    return masks


def masks_to_outlines(masks):
    """Get outlines of masks as a 0-1 array.

    Args:
        masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels.

    Returns:
        outlines (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are outlines.
    """
    if masks.ndim > 3 or masks.ndim < 2:
        raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" %
                         masks.ndim)
    outlines = np.zeros(masks.shape, bool)

    if masks.ndim == 3:
        for i in range(masks.shape[0]):
            outlines[i] = masks_to_outlines(masks[i])
        return outlines
    else:
        slices = find_objects(masks.astype(int))
        for i, si in enumerate(slices):
            if si is not None:
                sr, sc = si
                mask = (masks[sr, sc] == (i + 1)).astype(np.uint8)
                contours = cv2.findContours(mask, cv2.RETR_EXTERNAL,
                                            cv2.CHAIN_APPROX_NONE)
                pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T
                vr, vc = pvr + sr.start, pvc + sc.start
                outlines[vr, vc] = 1
        return outlines

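Toy sanity check (illustrative, not part of the diff): on a single 5x5 square mask, the contour has 16 pixels, and every outline pixel is also an edge pixel at the default threshold, since its distance to the boundary is 0.

    m = np.zeros((10, 10), np.int32)
    m[2:7, 2:7] = 1                   # one 5x5 mask
    outlines = masks_to_outlines(m)   # boolean contour image
    edges = masks_to_edges(m)         # pixels within threshold of the boundary
    assert outlines.sum() == 16 and edges[outlines].all()
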
def outlines_list(masks, multiprocessing_threshold=1000, multiprocessing=None):
    """Get outlines of masks as a list to loop over for plotting.

    Args:
        masks (ndarray): Array of masks.
        multiprocessing_threshold (int, optional): Number of masks above which multiprocessing is enabled. Defaults to 1000.
        multiprocessing (bool, optional): Flag to enable multiprocessing. Defaults to None.

    Returns:
        list: List of outlines.

    Notes:
        - This function is a wrapper for outlines_list_single and outlines_list_multi.
        - Multiprocessing is disabled for Windows.
    """
    # default to use multiprocessing if not few_masks, but allow user to override
    if multiprocessing is None:
        few_masks = np.max(masks) < multiprocessing_threshold
        multiprocessing = not few_masks

    # disable multiprocessing for Windows
    if os.name == "nt":
        if multiprocessing:
            logging.getLogger(__name__).warning(
                "Multiprocessing is disabled for Windows")
        multiprocessing = False

    if multiprocessing:
        return outlines_list_multi(masks)
    else:
        return outlines_list_single(masks)


def outlines_list_single(masks):
    """Get outlines of masks as a list to loop over for plotting.

    Args:
        masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)

    Returns:
        list: List of outlines as pixel coordinates.
    """
    outpix = []
    for n in np.unique(masks)[1:]:
        mn = masks == n
        if mn.sum() > 0:
            contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
                                        method=cv2.CHAIN_APPROX_NONE)
            contours = contours[-2]
            cmax = np.argmax([c.shape[0] for c in contours])
            pix = contours[cmax].astype(int).squeeze()
            if len(pix) > 4:
                outpix.append(pix)
            else:
                outpix.append(np.zeros((0, 2)))
    return outpix


def outlines_list_multi(masks, num_processes=None):
    """
    Get outlines of masks as a list to loop over for plotting.

    Args:
        masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)

    Returns:
        list: List of outlines as pixel coordinates.
    """
    if num_processes is None:
        num_processes = cpu_count()

    unique_masks = np.unique(masks)[1:]
    with Pool(processes=num_processes) as pool:
        outpix = pool.map(get_outline_multi, [(masks, n) for n in unique_masks])
    return outpix


def get_outline_multi(args):
    """Get the outline of a specific mask in a multi-mask image.

    Args:
        args (tuple): A tuple containing the masks and the mask number.

    Returns:
        numpy.ndarray: The outline of the specified mask as an array of coordinates.
    """
    masks, n = args
    mn = masks == n
    if mn.sum() > 0:
        contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
                                    method=cv2.CHAIN_APPROX_NONE)
        contours = contours[-2]
        cmax = np.argmax([c.shape[0] for c in contours])
        pix = contours[cmax].astype(int).squeeze()
        return pix if len(pix) > 4 else np.zeros((0, 2))
    return np.zeros((0, 2))

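A typical plotting loop over the outlines (illustrative; assumes matplotlib is available, and masks is a hypothetical labeled array) — each entry is an [npts x 2] array of (x, y) contour coordinates:

    import matplotlib.pyplot as plt

    for o in outlines_list(masks):
        if len(o):
            plt.plot(o[:, 0], o[:, 1], lw=0.5)
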
def dilate_masks(masks, n_iter=5):
    """Dilate masks by n_iter pixels.

    Args:
        masks (ndarray): Array of masks.
        n_iter (int, optional): Number of pixels to dilate the masks. Defaults to 5.

    Returns:
        ndarray: Dilated masks.
    """
    dilated_masks = masks.copy()
    for n in range(n_iter):
        # define the structuring element to use for dilation
        kernel = np.ones((3, 3), "uint8")
        # find the distance to each mask (distances are zero within masks)
        dist_transform = cv2.distanceTransform((dilated_masks == 0).astype("uint8"),
                                               cv2.DIST_L2, 5)
        # dilate each mask and assign to it the pixels along the border of the mask
        # (does not allow dilation into other masks since dist_transform is zero there)
        for i in range(1, np.max(masks) + 1):
            mask = (dilated_masks == i).astype("uint8")
            dilated_mask = cv2.dilate(mask, kernel, iterations=1)
            dilated_mask = np.logical_and(dist_transform < 2, dilated_mask)
            dilated_masks[dilated_mask > 0] = i
    return dilated_masks


def get_perimeter(points):
    """
    Calculate the perimeter of a set of points.

    Parameters:
        points (ndarray): An array of points with shape (npoints, ndim).

    Returns:
        float: The perimeter of the points.
    """
    if points.shape[0] > 4:
        points = np.append(points, points[:1], axis=0)
        return ((np.diff(points, axis=0)**2).sum(axis=1)**0.5).sum()
    else:
        return 0


def get_mask_compactness(masks):
    """
    Calculate the compactness of masks.

    Parameters:
        masks (ndarray): Labelled masks representing objects.

    Returns:
        ndarray: Array of compactness values for each mask.
    """
    perimeters = get_mask_perimeters(masks)
    npoints = np.unique(masks, return_counts=True)[1][1:]
    areas = npoints
    compactness = 4 * np.pi * areas / perimeters**2
    compactness[perimeters == 0] = 0
    compactness[compactness > 1.0] = 1.0
    return compactness


def get_mask_perimeters(masks):
    """
    Calculate the perimeters of the given masks.

    Parameters:
        masks (numpy.ndarray): Labelled masks representing objects.

    Returns:
        numpy.ndarray: Array containing the perimeters of each mask.
    """
    perimeters = np.zeros(masks.max())
    for n in range(masks.max()):
        mn = masks == (n + 1)
        if mn.sum() > 0:
            contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL,
                                        method=cv2.CHAIN_APPROX_NONE)[-2]
            perimeters[n] = np.array(
                [get_perimeter(c.astype(int).squeeze()) for c in contours]).sum()

    return perimeters


def circleMask(d0):
    """
    Creates an array with indices which are the radius of that x,y point.

    Args:
        d0 (tuple): Patch of (-d0, d0+1) over which radius is computed.

    Returns:
        tuple: A tuple containing:
            - rs (ndarray): Array of radii with shape (2*d0[0]+1, 2*d0[1]+1).
            - dx (ndarray): Indices of the patch along the x-axis.
            - dy (ndarray): Indices of the patch along the y-axis.
    """
    dx = np.tile(np.arange(-d0[1], d0[1] + 1), (2 * d0[0] + 1, 1))
    dy = np.tile(np.arange(-d0[0], d0[0] + 1), (2 * d0[1] + 1, 1))
    dy = dy.transpose()

    rs = (dy**2 + dx**2)**0.5
    return rs, dx, dy


def get_mask_stats(masks_true):
    """
    Calculate various statistics for the given masks.

    Parameters:
        masks_true (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)

    Returns:
        convexity (ndarray): Convexity values for each mask.
        solidity (ndarray): Solidity values for each mask.
        compactness (ndarray): Compactness values for each mask.
    """
    mask_perimeters = get_mask_perimeters(masks_true)

    # disk for compactness
    rs, dy, dx = circleMask(np.array([100, 100]))
    rsort = np.sort(rs.flatten())

    # area for solidity
    npoints = np.unique(masks_true, return_counts=True)[1][1:]
    areas = npoints - mask_perimeters / 2 - 1

    compactness = np.zeros(masks_true.max())
    convexity = np.zeros(masks_true.max())
    solidity = np.zeros(masks_true.max())
    convex_perimeters = np.zeros(masks_true.max())
    convex_areas = np.zeros(masks_true.max())
    for ic in range(masks_true.max()):
        points = np.array(np.nonzero(masks_true == (ic + 1))).T
        if len(points) > 15 and mask_perimeters[ic] > 0:
            med = np.median(points, axis=0)
            # compute compactness of ROI
            r2 = ((points - med)**2).sum(axis=1)**0.5
            compactness[ic] = (rsort[:r2.size].mean() + 1e-10) / r2.mean()
            try:
                hull = ConvexHull(points)
                convex_perimeters[ic] = hull.area
                convex_areas[ic] = hull.volume
            except:
                convex_perimeters[ic] = 0

    convexity[mask_perimeters > 0.0] = (convex_perimeters[mask_perimeters > 0.0] /
                                        mask_perimeters[mask_perimeters > 0.0])
    solidity[convex_areas > 0.0] = (areas[convex_areas > 0.0] /
                                    convex_areas[convex_areas > 0.0])
    convexity = np.clip(convexity, 0.0, 1.0)
    solidity = np.clip(solidity, 0.0, 1.0)
    compactness = np.clip(compactness, 0.0, 1.0)
    return convexity, solidity, compactness

def get_masks_unet(output, cell_threshold=0, boundary_threshold=0):
    """Create masks using cell probability and cell boundary.

    Args:
        output (ndarray): The output array containing cell probability and cell boundary.
        cell_threshold (float, optional): The threshold value for cell probability. Defaults to 0.
        boundary_threshold (float, optional): The threshold value for cell boundary. Defaults to 0.

    Returns:
        ndarray: The masks representing the segmented cells.
    """
    cells = (output[..., 1] - output[..., 0]) > cell_threshold
    selem = generate_binary_structure(cells.ndim, connectivity=1)
    labels, nlabels = label(cells, selem)

    if output.shape[-1] > 2:
        slices = find_objects(labels)
        dists = 10000 * np.ones(labels.shape, np.float32)
        mins = np.zeros(labels.shape, np.int32)
        borders = np.logical_and(~(labels > 0), output[..., 2] > boundary_threshold)
        pad = 10
        for i, slc in enumerate(slices):
            if slc is not None:
                slc_pad = tuple([
                    slice(max(0, sli.start - pad), min(labels.shape[j], sli.stop + pad))
                    for j, sli in enumerate(slc)
                ])
                msk = (labels[slc_pad] == (i + 1)).astype(np.float32)
                msk = 1 - gaussian_filter(msk, 5)
                dists[slc_pad] = np.minimum(dists[slc_pad], msk)
                mins[slc_pad][dists[slc_pad] == msk] = (i + 1)
        labels[labels == 0] = borders[labels == 0] * mins[labels == 0]

    masks = labels
    shape0 = masks.shape
    _, masks = np.unique(masks, return_inverse=True)
    masks = np.reshape(masks, shape0)
    return masks


def stitch3D(masks, stitch_threshold=0.25):
    """
    Stitch 2D masks into a 3D volume using a stitch_threshold on IOU.

    Args:
        masks (list or ndarray): List of 2D masks.
        stitch_threshold (float, optional): Threshold value for stitching. Defaults to 0.25.

    Returns:
        list: List of stitched 3D masks.
    """
    mmax = masks[0].max()
    empty = 0
    for i in trange(len(masks) - 1):
        iou = metrics._intersection_over_union(masks[i + 1], masks[i])[1:, 1:]
        if not iou.size and empty == 0:
            masks[i + 1] = masks[i + 1]
            mmax = masks[i + 1].max()
        elif not iou.size and not empty == 0:
            icount = masks[i + 1].max()
            istitch = np.arange(mmax + 1, mmax + icount + 1, 1, masks.dtype)
            mmax += icount
            istitch = np.append(np.array(0), istitch)
            masks[i + 1] = istitch[masks[i + 1]]
        else:
            iou[iou < stitch_threshold] = 0.0
            iou[iou < iou.max(axis=0)] = 0.0
            istitch = iou.argmax(axis=1) + 1
            ino = np.nonzero(iou.max(axis=1) == 0.0)[0]
            istitch[ino] = np.arange(mmax + 1, mmax + len(ino) + 1, 1, masks.dtype)
            mmax += len(ino)
            istitch = np.append(np.array(0), istitch)
            masks[i + 1] = istitch[masks[i + 1]]
            empty = 1

    return masks

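Illustrative use of stitch3D (masks_z is a hypothetical [Lz x Ly x Lx] int array of independent per-plane labels): objects that overlap across consecutive planes with IoU at or above stitch_threshold end up sharing one label.

    stitched = stitch3D(masks_z.copy(), stitch_threshold=0.25)
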
def diameters(masks):
    """
    Calculate the diameters of the objects in the given masks.

    Parameters:
        masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)

    Returns:
        tuple: A tuple containing the median diameter and an array of diameters for each object.

    Examples:
        >>> masks = np.array([[0, 1, 1], [1, 0, 0], [1, 1, 0]])
        >>> diameters(masks)
        (1.0, array([1.41421356, 1.0, 1.0]))
    """
    uniq, counts = fastremap.unique(masks.astype("int32"), return_counts=True)
    counts = counts[1:]
    md = np.median(counts**0.5)
    if np.isnan(md):
        md = 0
    md /= (np.pi**0.5) / 2
    return md, counts**0.5


def radius_distribution(masks, bins):
    """
    Calculate the radius distribution of masks.

    Args:
        masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)
        bins (int): Number of bins for the histogram.

    Returns:
        A tuple containing a normalized histogram of radii, the median radius, and an array of radii.
    """
    unique, counts = np.unique(masks, return_counts=True)
    counts = counts[unique != 0]
    nb, _ = np.histogram((counts**0.5) * 0.5, bins)
    nb = nb.astype(np.float32)
    if nb.sum() > 0:
        nb = nb / nb.sum()
    md = np.median(counts**0.5) * 0.5
    if np.isnan(md):
        md = 0
    md /= (np.pi**0.5) / 2
    return nb, md, (counts**0.5) / 2


def size_distribution(masks):
    """
    Calculates the size distribution of masks.

    Args:
        masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...)

    Returns:
        float: The ratio of the 25th percentile of mask sizes to the 75th percentile of mask sizes.
    """
    counts = np.unique(masks, return_counts=True)[1][1:]
    return np.percentile(counts, 25) / np.percentile(counts, 75)

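Sanity check of the diameter estimate (illustrative, not part of the diff): a mask of N pixels is treated as a disk of area N, so its diameter is sqrt(N) / (sqrt(pi)/2) = 2*sqrt(N/pi).

    r = 10
    yy, xx = np.mgrid[-15:16, -15:16]
    disk = (yy**2 + xx**2 <= r**2).astype(np.int32)   # one disk of radius 10
    md, diams = diameters(disk)
    assert abs(md - 2 * r) < 1                        # area ~ pi*r**2, so md ~ 20
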
def fill_holes_and_remove_small_masks(masks, min_size=15):
    """Fills holes in masks (2D/3D) and discards masks smaller than min_size.

    This function fills holes in each mask using fill_voids.fill.
    It also removes masks that are smaller than the specified min_size.

    Parameters:
        masks (ndarray): Int, 2D or 3D array of labelled masks.
            0 represents no mask, while positive integers represent mask labels.
            The size can be [Ly x Lx] or [Lz x Ly x Lx].
        min_size (int, optional): Minimum number of pixels per mask.
            Masks smaller than min_size will be removed.
            Set to -1 to turn off this functionality. Default is 15.

    Returns:
        ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed.
            0 represents no mask, while positive integers represent mask labels.
            The size is [Ly x Lx] or [Lz x Ly x Lx].
    """

    if masks.ndim > 3 or masks.ndim < 2:
        raise ValueError("fill_holes_and_remove_small_masks takes 2D or 3D array, not %dD array" %
                         masks.ndim)

    # Filter small masks
    if min_size > 0:
        counts = fastremap.unique(masks, return_counts=True)[1][1:]
        masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
        fastremap.renumber(masks, in_place=True)

    slices = find_objects(masks)
    j = 0
    for i, slc in enumerate(slices):
        if slc is not None:
            msk = masks[slc] == (i + 1)
            msk = fill_voids.fill(msk)
            masks[slc][msk] = (j + 1)
            j += 1

    if min_size > 0:
        counts = fastremap.unique(masks, return_counts=True)[1][1:]
        masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1)
        fastremap.renumber(masks, in_place=True)

    return masks
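A quick demonstration (illustrative, not part of the diff): a mask with an internal hole gets filled, and a 2-pixel speck is dropped at the default min_size=15.

    m = np.zeros((20, 20), np.int32)
    m[5:15, 5:15] = 1
    m[8:12, 8:12] = 0    # hole inside mask 1
    m[0, 0:2] = 2        # speck smaller than min_size
    clean = fill_holes_and_remove_small_masks(m)
    assert clean[9, 9] == 1 and not (clean == 2).any()
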
models/seg_post_model/cellpose/version.py
ADDED
@@ -0,0 +1,18 @@
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.
"""
from importlib.metadata import PackageNotFoundError, version
import sys
from platform import python_version
import torch

try:
    version = version("cellpose")
except PackageNotFoundError:
    version = "unknown"

version_str = f"""
cellpose version: \t{version}
platform:       \t{sys.platform}
python version: \t{python_version()}
torch version:  \t{torch.__version__}"""
models/seg_post_model/cellpose/vit_sam.py
ADDED
@@ -0,0 +1,195 @@
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""

import torch
from segment_anything import sam_model_registry
torch.backends.cuda.matmul.allow_tf32 = True
from torch import nn
import torch.nn.functional as F

class Transformer(nn.Module):
    def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4,
                 checkpoint=None, dtype=torch.float32):
        super(Transformer, self).__init__()
        """
        print(self.encoder.patch_embed)
        PatchEmbed(
            (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
        )
        print(self.encoder.neck)
        Sequential(
            (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
            (1): LayerNorm2d()
            (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
            (3): LayerNorm2d()
        )
        """
        # instantiate the vit model, default to not loading SAM
        # checkpoint = sam_vit_l_0b3195.pth is standard pretrained SAM
        self.encoder = sam_model_registry[backbone](checkpoint).image_encoder
        w = self.encoder.patch_embed.proj.weight.detach()
        nchan = w.shape[0]

        # change token size to ps x ps
        self.ps = ps
        self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
        self.encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]

        # adjust position embeddings for new bsize and new token size
        ds = (1024 // 16) // (bsize // ps)
        self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:, ::ds, ::ds], requires_grad=True)

        # readout weights for nout output channels
        # if nout is changed, weights will not load correctly from pretrained Cellpose-SAM
        self.nout = nout
        self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1)

        # W2 reshapes token space to pixel space, not trainable
        self.W2 = nn.Parameter(torch.eye(self.nout * ps**2).reshape(self.nout * ps**2, self.nout, ps, ps),
                               requires_grad=False)

        # fraction of layers to drop at random during training
        self.rdrop = rdrop

        # average diameter of ROIs from training images from fine-tuning
        self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
        # average diameter of ROIs during main training
        self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)

        # set attention to global in every layer
        for blk in self.encoder.blocks:
            blk.window_size = 0

        self.dtype = dtype

    def forward(self, x, feat=None):
        # same progression as SAM until readout
        x = self.encoder.patch_embed(x)
        if feat is not None:
            feat = self.encoder.patch_embed(feat)
            x = x + x * feat * 0.5

        if self.encoder.pos_embed is not None:
            x = x + self.encoder.pos_embed

        if self.training and self.rdrop > 0:
            nlay = len(self.encoder.blocks)
            rdrop = (torch.rand((len(x), nlay), device=x.device) <
                     torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
            for i, blk in enumerate(self.encoder.blocks):
                mask = rdrop[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                x = x * mask + blk(x) * (1 - mask)
        else:
            for blk in self.encoder.blocks:
                x = blk(x)

        x = self.encoder.neck(x.permute(0, 3, 1, 2))

        # readout is changed here
        x1 = self.out(x)
        x1 = F.conv_transpose2d(x1, self.W2, stride=self.ps, padding=0)

        # maintain the second output of feature size 256 for backwards compatibility
        return x1, torch.randn((x.shape[0], 256), device=x.device)

    def load_model(self, PATH, device, strict=False):
        state_dict = torch.load(PATH, map_location=device, weights_only=True)
        keys = [k for k in state_dict.keys()]
        if keys[0][:7] == "module.":
            from collections import OrderedDict
            new_state_dict = OrderedDict()
            for k, v in state_dict.items():
                name = k[7:]  # remove 'module.' of DataParallel/DistributedDataParallel
                new_state_dict[name] = v
            self.load_state_dict(new_state_dict, strict=strict)
        else:
            self.load_state_dict(state_dict, strict=strict)

        if self.dtype != torch.float32:
            self = self.to(self.dtype)

    @property
    def device(self):
        """
        Get the device of the model.

        Returns:
            torch.device: The device of the model.
        """
        return next(self.parameters()).device

    def save_model(self, filename):
        """
        Save the model to a file.

        Args:
            filename (str): The path to the file where the model will be saved.
        """
        torch.save(self.state_dict(), filename)

| 135 |
+
class CPnetBioImageIO(Transformer):
|
| 136 |
+
"""
|
| 137 |
+
A subclass of the CP-SAM model compatible with the BioImage.IO Spec.
|
| 138 |
+
|
| 139 |
+
This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
|
| 140 |
+
allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
|
| 141 |
+
"""
|
| 142 |
+
|
| 143 |
+
def forward(self, x):
|
| 144 |
+
"""
|
| 145 |
+
Perform a forward pass of the CPnet model and return unpacked tensors.
|
| 146 |
+
|
| 147 |
+
Args:
|
| 148 |
+
x (torch.Tensor): Input tensor.
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
tuple: A tuple containing the output tensor, style tensor, and downsampled tensors.
|
| 152 |
+
"""
|
| 153 |
+
output_tensor, style_tensor, downsampled_tensors = super().forward(x)
|
| 154 |
+
return output_tensor, style_tensor, *downsampled_tensors
|
| 155 |
+
|
| 156 |
+
|
| 157 |
+
def load_model(self, filename, device=None):
|
| 158 |
+
"""
|
| 159 |
+
Load the model from a file.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
filename (str): The path to the file where the model is saved.
|
| 163 |
+
device (torch.device, optional): The device to load the model on. Defaults to None.
|
| 164 |
+
"""
|
| 165 |
+
if (device is not None) and (device.type != "cpu"):
|
| 166 |
+
state_dict = torch.load(filename, map_location=device, weights_only=True)
|
| 167 |
+
else:
|
| 168 |
+
self.__init__(self.nout)
|
| 169 |
+
state_dict = torch.load(filename, map_location=torch.device("cpu"),
|
| 170 |
+
weights_only=True)
|
| 171 |
+
|
| 172 |
+
self.load_state_dict(state_dict)
|
| 173 |
+
|
| 174 |
+
def load_state_dict(self, state_dict):
|
| 175 |
+
"""
|
| 176 |
+
Load the state dictionary into the model.
|
| 177 |
+
|
| 178 |
+
This method overrides the default `load_state_dict` to handle Cellpose's custom
|
| 179 |
+
loading mechanism and ensures compatibility with BioImage.IO Core.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
state_dict (Mapping[str, Any]): A state dictionary to load into the model
|
| 183 |
+
"""
|
| 184 |
+
if state_dict["output.2.weight"].shape[0] != self.nout:
|
| 185 |
+
for name in self.state_dict():
|
| 186 |
+
if "output" not in name:
|
| 187 |
+
self.state_dict()[name].copy_(state_dict[name])
|
| 188 |
+
else:
|
| 189 |
+
super().load_state_dict(
|
| 190 |
+
{name: param for name, param in state_dict.items()},
|
| 191 |
+
strict=False)
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
|
| 195 |
+
|
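As an aside, the changed readout above is a pixel-shuffle written as a transposed convolution: the frozen identity W2 scatters each token's nout*ps**2 channels into a ps x ps block of nout-channel pixels. A minimal standalone sketch, not part of the uploaded files; shapes assume the constructor defaults ps=16, nout=3 and a 256-pixel crop, so the neck emits a 16x16 token grid with 256 channels:

import torch
import torch.nn.functional as F
from torch import nn

ps, nout = 16, 3
out = nn.Conv2d(256, nout * ps**2, kernel_size=1)                 # 1x1 readout over the token grid
W2 = torch.eye(nout * ps**2).reshape(nout * ps**2, nout, ps, ps)  # frozen identity, as in __init__

tokens = torch.randn(2, 256, 16, 16)                # neck output: (batch, 256, H/ps, W/ps)
y = out(tokens)                                     # (2, nout*ps**2, 16, 16)
pix = F.conv_transpose2d(y, W2, stride=ps)          # (2, nout, 256, 256)
assert torch.allclose(pix, F.pixel_shuffle(y, ps))  # identical rearrangement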
models/seg_post_model/cellpose/vit_sam_new.py
ADDED
@@ -0,0 +1,197 @@
+"""
+Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
+"""
+
+import torch
+from segment_anything import sam_model_registry
+torch.backends.cuda.matmul.allow_tf32 = True
+from torch import nn
+import torch.nn.functional as F
+
+class Transformer(nn.Module):
+    def __init__(self, backbone="vit_l", ps=16, nout=3, bsize=256, rdrop=0.4,
+                 checkpoint=None, dtype=torch.float32):
+        super(Transformer, self).__init__()
+        """
+        print(self.encoder.patch_embed)
+        PatchEmbed(
+          (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
+        )
+        print(self.encoder.neck)
+        Sequential(
+          (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
+          (1): LayerNorm2d()
+          (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
+          (3): LayerNorm2d()
+        )
+        """
+        # instantiate the vit model; if no checkpoint is given, fall back to pretrained SAM
+        # checkpoint = sam_vit_l_0b3195.pth is standard pretrained SAM
+        if checkpoint is None:
+            checkpoint = "sam_vit_l_0b3195.pth"
+        self.encoder = sam_model_registry[backbone](checkpoint).image_encoder
+        w = self.encoder.patch_embed.proj.weight.detach()
+        nchan = w.shape[0]
+
+        # change token size to ps x ps
+        self.ps = ps
+        # self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
+        # self.encoder.patch_embed.proj.weight.data = w[:,:,::16//ps,::16//ps]
+
+        # adjust position embeddings for new bsize and new token size
+        ds = (1024 // 16) // (bsize // ps)
+        self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:,::ds,::ds], requires_grad=True)
+
+        # readout weights for nout output channels
+        # if nout is changed, weights will not load correctly from pretrained Cellpose-SAM
+        self.nout = nout
+        self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1)
+
+        # W2 reshapes token space to pixel space, not trainable
+        self.W2 = nn.Parameter(torch.eye(self.nout * ps**2).reshape(self.nout*ps**2, self.nout, ps, ps),
+                               requires_grad=False)
+
+        # fraction of layers to drop at random during training
+        self.rdrop = rdrop
+
+        # average diameter of ROIs from training images from fine-tuning
+        self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
+        # average diameter of ROIs during main training
+        self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)
+
+        # set attention to global in every layer
+        for blk in self.encoder.blocks:
+            blk.window_size = 0
+
+        self.dtype = dtype
+
+    def forward(self, x, feat=None):
+        # same progression as SAM until readout
+        x = self.encoder.patch_embed(x)
+        if feat is not None:
+            feat = self.encoder.patch_embed(feat)
+            x = x + x * feat * 0.5
+
+        if self.encoder.pos_embed is not None:
+            x = x + self.encoder.pos_embed
+
+        if self.training and self.rdrop > 0:
+            nlay = len(self.encoder.blocks)
+            rdrop = (torch.rand((len(x), nlay), device=x.device) <
+                     torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
+            for i, blk in enumerate(self.encoder.blocks):
+                mask = rdrop[:,i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
+                x = x * mask + blk(x) * (1 - mask)
+        else:
+            for blk in self.encoder.blocks:
+                x = blk(x)
+
+        x = self.encoder.neck(x.permute(0, 3, 1, 2))
+
+        # readout is changed here
+        x1 = self.out(x)
+        x1 = F.conv_transpose2d(x1, self.W2, stride=self.ps, padding=0)
+
+        # maintain the second output of feature size 256 for backwards compatibility
+        return x1, torch.randn((x.shape[0], 256), device=x.device)
+
+    def load_model(self, PATH, device, strict=False):
+        state_dict = torch.load(PATH, map_location=device, weights_only=True)
+        keys = [k for k in state_dict.keys()]
+        if keys[0][:7] == "module.":
+            from collections import OrderedDict
+            new_state_dict = OrderedDict()
+            for k, v in state_dict.items():
+                name = k[7:]  # remove 'module.' of DataParallel/DistributedDataParallel
+                new_state_dict[name] = v
+            self.load_state_dict(new_state_dict, strict=strict)
+        else:
+            self.load_state_dict(state_dict, strict=strict)
+
+        if self.dtype != torch.float32:
+            self = self.to(self.dtype)
+
+    @property
+    def device(self):
+        """
+        Get the device of the model.
+
+        Returns:
+            torch.device: The device of the model.
+        """
+        return next(self.parameters()).device
+
+    def save_model(self, filename):
+        """
+        Save the model to a file.
+
+        Args:
+            filename (str): The path to the file where the model will be saved.
+        """
+        torch.save(self.state_dict(), filename)
+
+
+class CPnetBioImageIO(Transformer):
+    """
+    A subclass of the CP-SAM model compatible with the BioImage.IO Spec.
+
+    This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
+    allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
+    """
+
+    def forward(self, x):
+        """
+        Perform a forward pass of the CPnet model and return unpacked tensors.
+
+        Args:
+            x (torch.Tensor): Input tensor.
+
+        Returns:
+            tuple: A tuple containing the output tensor, style tensor, and downsampled tensors.
+        """
+        output_tensor, style_tensor, downsampled_tensors = super().forward(x)
+        return output_tensor, style_tensor, *downsampled_tensors
+
+    def load_model(self, filename, device=None):
+        """
+        Load the model from a file.
+
+        Args:
+            filename (str): The path to the file where the model is saved.
+            device (torch.device, optional): The device to load the model on. Defaults to None.
+        """
+        if (device is not None) and (device.type != "cpu"):
+            state_dict = torch.load(filename, map_location=device, weights_only=True)
+        else:
+            self.__init__(self.nout)
+            state_dict = torch.load(filename, map_location=torch.device("cpu"),
+                                    weights_only=True)
+
+        self.load_state_dict(state_dict)
+
+    def load_state_dict(self, state_dict):
+        """
+        Load the state dictionary into the model.
+
+        This method overrides the default `load_state_dict` to handle Cellpose's custom
+        loading mechanism and ensures compatibility with BioImage.IO Core.
+
+        Args:
+            state_dict (Mapping[str, Any]): A state dictionary to load into the model
+        """
+        if state_dict["output.2.weight"].shape[0] != self.nout:
+            for name in self.state_dict():
+                if "output" not in name:
+                    self.state_dict()[name].copy_(state_dict[name])
+        else:
+            super().load_state_dict(
+                {name: param for name, param in state_dict.items()},
+                strict=False)
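To close the loop, a hypothetical usage sketch for the backbone defined in these two files. The import path mirrors this repo's file layout (package initialization assumed); it also assumes the segment_anything package is installed and the SAM checkpoint sam_vit_l_0b3195.pth referenced in __init__ sits in the working directory:

import torch
from models.seg_post_model.cellpose.vit_sam import Transformer

net = Transformer(backbone="vit_l", ps=16, nout=3, bsize=256)  # constructor defaults
net.eval()                                                     # disables the random layer drop
with torch.no_grad():
    flows, style = net(torch.randn(1, 3, 256, 256))            # one 256x256 RGB crop
print(flows.shape)  # torch.Size([1, 3, 256, 256]): nout channels at full resolution
print(style.shape)  # torch.Size([1, 256]): placeholder style vector kept for compatibility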