diff --git a/README.md b/README.md index 05860c677d3641bfee975b950e63b4f9dae3ee8e..4520d21f4e3c9a52a37dec3d0c12f094cd324d65 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,12 @@ --- title: MicroscopyMatching -emoji: 🐢 -colorFrom: green -colorTo: pink +emoji: 🚀 +colorFrom: gray +colorTo: red sdk: gradio -sdk_version: 6.3.0 +sdk_version: 5.49.1 app_file: app.py +python_version: 3.11 pinned: false --- diff --git a/_utils/attn_utils.py b/_utils/attn_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9718ef180217c325de50192851b922bb8cc063b2 --- /dev/null +++ b/_utils/attn_utils.py @@ -0,0 +1,592 @@ +import abc + +import cv2 +import numpy as np +import torch +from IPython.display import display +from PIL import Image +from typing import Union, Tuple, List +from einops import rearrange, repeat +import math +from torch import nn, einsum +from inspect import isfunction +from diffusers.utils import logging +try: + from diffusers.models.unet_2d_condition import UNet2DConditionOutput +except: + from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput + +try: + from diffusers.models.cross_attention import CrossAttention +except: + from diffusers.models.attention_processor import Attention as CrossAttention + +MAX_NUM_WORDS = 77 +LOW_RESOURCE = False + +class CountingCrossAttnProcessor1: + + def __init__(self, attnstore, place_in_unet): + super().__init__() + self.attnstore = attnstore + self.place_in_unet = place_in_unet + + def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, dim = hidden_states.shape + h = attn_layer.heads + q = attn_layer.to_q(hidden_states) + is_cross = encoder_hidden_states is not None + context = encoder_hidden_states if is_cross else hidden_states + k = attn_layer.to_k(context) + v = attn_layer.to_v(context) + # q = attn_layer.reshape_heads_to_batch_dim(q) + # k = attn_layer.reshape_heads_to_batch_dim(k) + # v = 
attn_layer.reshape_heads_to_batch_dim(v) + # q = attn_layer.head_to_batch_dim(q) + # k = attn_layer.head_to_batch_dim(k) + # v = attn_layer.head_to_batch_dim(v) + q = self.head_to_batch_dim(q, h) + k = self.head_to_batch_dim(k, h) + v = self.head_to_batch_dim(v, h) + + sim = torch.einsum("b i d, b j d -> b i j", q, k) * attn_layer.scale + + if attention_mask is not None: + attention_mask = attention_mask.reshape(batch_size, -1) + max_neg_value = -torch.finfo(sim.dtype).max + attention_mask = attention_mask[:, None, :].repeat(h, 1, 1) + sim.masked_fill_(~attention_mask, max_neg_value) + + # attention, what we cannot get enough of + attn_ = sim.softmax(dim=-1).clone() + # softmax = nn.Softmax(dim=-1) + # attn_ = softmax(sim) + self.attnstore(attn_, is_cross, self.place_in_unet) + out = torch.einsum("b i j, b j d -> b i d", attn_, v) + # out = attn_layer.batch_to_head_dim(out) + out = self.batch_to_head_dim(out, h) + + if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList: + to_out = attn_layer.to_out[0] + else: + to_out = attn_layer.to_out + + out = to_out(out) + return out + + def batch_to_head_dim(self, tensor, head_size): + # head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def head_to_batch_dim(self, tensor, head_size, out_dim=3): + # head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + + return tensor + + +def register_attention_control(model, controller): + + attn_procs = {} + cross_att_count = 0 + for name in model.unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else 
model.unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.unet.config.block_out_channels[-1] + place_in_unet = "mid" + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] + place_in_unet = "up" + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.unet.config.block_out_channels[block_id] + place_in_unet = "down" + else: + continue + + cross_att_count += 1 + # attn_procs[name] = AttendExciteCrossAttnProcessor( + # attnstore=controller, place_in_unet=place_in_unet + # ) + attn_procs[name] = CountingCrossAttnProcessor1( + attnstore=controller, place_in_unet=place_in_unet + ) + + model.unet.set_attn_processor(attn_procs) + controller.num_att_layers = cross_att_count + +def register_hier_output(model): + self = model.unet + from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None, + attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None, down_block_additional_residuals=None, + mid_block_additional_residual=None, encoder_attention_mask=None, return_dict=True): + + out_list = [] + + + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + 
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. 
+ class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which 
requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = 
self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) # 1, 320, 64, 64 + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = 
new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_block_additional_residuals) > 0 + and sample.shape == down_block_additional_residuals[0].shape + ): + sample += down_block_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + # if i in [1, 4, 7]: + out_list.append(sample) + + # 6. 
post-process + if self.conv_norm_out: + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + + if not return_dict: + return (sample,) + + return UNet2DConditionOutput(sample=sample), out_list + + self.forward = forward + + +class AttentionControl(abc.ABC): + + def step_callback(self, x_t): + return x_t + + def between_steps(self): + return + + @property + def num_uncond_att_layers(self): + return 0 + + @abc.abstractmethod + def forward(self, attn, is_cross: bool, place_in_unet: str): + raise NotImplementedError + + def __call__(self, attn, is_cross: bool, place_in_unet: str): + if self.cur_att_layer >= self.num_uncond_att_layers: + # self.forward(attn, is_cross, place_in_unet) + if LOW_RESOURCE: + attn = self.forward(attn, is_cross, place_in_unet) + else: + h = attn.shape[0] + attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet) + self.cur_att_layer += 1 + if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers: + self.cur_att_layer = 0 + self.cur_step += 1 + self.between_steps() + return attn + + def reset(self): + self.cur_step = 0 + self.cur_att_layer = 0 + + def __init__(self): + self.cur_step = 0 + self.num_att_layers = -1 + self.cur_att_layer = 0 + + +class EmptyControl(AttentionControl): + + def forward(self, attn, is_cross: bool, place_in_unet: str): + return attn + + +class AttentionStore(AttentionControl): + + @staticmethod + def get_empty_store(): + return {"down_cross": [], "mid_cross": [], "up_cross": [], + "down_self": [], "mid_self": [], "up_self": []} + + def forward(self, attn, is_cross: bool, place_in_unet: str): + key = f"{place_in_unet}_{'cross' if is_cross else 'self'}" + if attn.shape[1] <= self.max_size ** 2: # avoid memory overhead + self.step_store[key].append(attn) + return attn + + def between_steps(self): + self.attention_store = self.step_store + if self.save_global_store: + with torch.no_grad(): + if len(self.global_store) == 0: + self.global_store = 
self.step_store + else: + for key in self.global_store: + for i in range(len(self.global_store[key])): + self.global_store[key][i] += self.step_store[key][i].detach() + self.step_store = self.get_empty_store() + self.step_store = self.get_empty_store() + + def get_average_attention(self): + average_attention = self.attention_store + return average_attention + + def get_average_global_attention(self): + average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in + self.attention_store} + return average_attention + + def reset(self): + super(AttentionStore, self).reset() + self.step_store = self.get_empty_store() + self.attention_store = {} + self.global_store = {} + + def __init__(self, max_size=32, save_global_store=False): + ''' + Initialize an empty AttentionStore + :param step_index: used to visualize only a specific step in the diffusion process + ''' + super(AttentionStore, self).__init__() + self.save_global_store = save_global_store + self.max_size = max_size + self.step_store = self.get_empty_store() + self.attention_store = {} + self.global_store = {} + self.curr_step_index = 0 + +def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int): + out = [] + attention_maps = attention_store.get_average_attention() + num_pixels = res ** 2 + for location in from_where: + for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]: + if item.shape[1] == num_pixels: + cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select] + out.append(cross_maps) + out = torch.cat(out, dim=0) + out = out.sum(0) / out.shape[0] + return out + + +def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0): + tokens = tokenizer.encode(prompts[select]) + decoder = tokenizer.decode + attention_maps = aggregate_attention(attention_store, res, from_where, True, select) + images = [] 
+ for i in range(len(tokens)): + image = attention_maps[:, :, i] + image = 255 * image / image.max() + image = image.unsqueeze(-1).expand(*image.shape, 3) + image = image.numpy().astype(np.uint8) + image = np.array(Image.fromarray(image).resize((256, 256))) + image = text_under_image(image, decoder(int(tokens[i]))) + images.append(image) + view_images(np.stack(images, axis=0)) + + +def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str], + max_com=10, select: int = 0): + attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2)) + u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True)) + images = [] + for i in range(max_com): + image = vh[i].reshape(res, res) + image = image - image.min() + image = 255 * image / image.max() + image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8) + image = Image.fromarray(image).resize((256, 256)) + image = np.array(image) + images.append(image) + view_images(np.concatenate(images, axis=1)) + +def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)): + h, w, c = image.shape + offset = int(h * .2) + img = np.ones((h + offset, w, c), dtype=np.uint8) * 255 + font = cv2.FONT_HERSHEY_SIMPLEX + # font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size) + img[:h] = image + textsize = cv2.getTextSize(text, font, 1, 2)[0] + text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2 + cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2) + return img + + +def view_images(images, num_rows=1, offset_ratio=0.02): + if type(images) is list: + num_empty = len(images) % num_rows + elif images.ndim == 4: + num_empty = images.shape[0] % num_rows + else: + images = [images] + num_empty = 0 + + empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255 + images = [image.astype(np.uint8) for 
image in images] + [empty_images] * num_empty + num_items = len(images) + + h, w, c = images[0].shape + offset = int(h * offset_ratio) + num_cols = num_items // num_rows + image_ = np.ones((h * num_rows + offset * (num_rows - 1), + w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255 + for i in range(num_rows): + for j in range(num_cols): + image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[ + i * num_cols + j] + + pil_img = Image.fromarray(image_) + display(pil_img) + +def self_cross_attn(self_attn, cross_attn): + res = self_attn.shape[0] + assert res == cross_attn.shape[0] + # cross attn [res, res] -> [res*res] + cross_attn_ = cross_attn.reshape([res*res]) + # self_attn [res, res, res*res] + self_cross_attn = cross_attn_ * self_attn + self_cross_attn = self_cross_attn.mean(-1).unsqueeze(0).unsqueeze(0) + return self_cross_attn \ No newline at end of file diff --git a/_utils/attn_utils_new.py b/_utils/attn_utils_new.py new file mode 100644 index 0000000000000000000000000000000000000000..76855d6f4a00ecb4abd2fb71e575615d56d34de6 --- /dev/null +++ b/_utils/attn_utils_new.py @@ -0,0 +1,610 @@ +import abc + +import cv2 +import numpy as np +import torch +from IPython.display import display +from PIL import Image +from typing import Union, Tuple, List +from einops import rearrange, repeat +import math +from torch import nn, einsum +from inspect import isfunction +from diffusers.utils import logging +try: + from diffusers.models.unet_2d_condition import UNet2DConditionOutput +except: + from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput +try: + from diffusers.models.cross_attention import CrossAttention +except: + from diffusers.models.attention_processor import Attention as CrossAttention +from typing import Any, Dict, List, Optional, Tuple, Union +MAX_NUM_WORDS = 77 +LOW_RESOURCE = False + +class CountingCrossAttnProcessor1: + + def __init__(self, attnstore, place_in_unet): + 
super().__init__() + self.attnstore = attnstore + self.place_in_unet = place_in_unet + + def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, dim = hidden_states.shape + h = attn_layer.heads + q = attn_layer.to_q(hidden_states) + is_cross = encoder_hidden_states is not None + context = encoder_hidden_states if is_cross else hidden_states + k = attn_layer.to_k(context) + v = attn_layer.to_v(context) + # q = attn_layer.reshape_heads_to_batch_dim(q) + # k = attn_layer.reshape_heads_to_batch_dim(k) + # v = attn_layer.reshape_heads_to_batch_dim(v) + # q = attn_layer.head_to_batch_dim(q) + # k = attn_layer.head_to_batch_dim(k) + # v = attn_layer.head_to_batch_dim(v) + q = self.head_to_batch_dim(q, h) + k = self.head_to_batch_dim(k, h) + v = self.head_to_batch_dim(v, h) + + sim = torch.einsum("b i d, b j d -> b i j", q, k) * attn_layer.scale + + if attention_mask is not None: + attention_mask = attention_mask.reshape(batch_size, -1) + max_neg_value = -torch.finfo(sim.dtype).max + attention_mask = attention_mask[:, None, :].repeat(h, 1, 1) + sim.masked_fill_(~attention_mask, max_neg_value) + + # attention, what we cannot get enough of + attn_ = sim.softmax(dim=-1).clone() + # softmax = nn.Softmax(dim=-1) + # attn_ = softmax(sim) + self.attnstore(attn_, is_cross, self.place_in_unet) + out = torch.einsum("b i j, b j d -> b i d", attn_, v) + # out = attn_layer.batch_to_head_dim(out) + out = self.batch_to_head_dim(out, h) + + if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList: + to_out = attn_layer.to_out[0] + else: + to_out = attn_layer.to_out + + out = to_out(out) + return out + + def batch_to_head_dim(self, tensor, head_size): + # head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + 
return tensor + + def head_to_batch_dim(self, tensor, head_size, out_dim=3): + # head_size = self.heads + batch_size, seq_len, dim = tensor.shape + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3) + + if out_dim == 3: + tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size) + + return tensor + + +def register_attention_control(model, controller): + + attn_procs = {} + cross_att_count = 0 + for name in model.unet.attn_processors.keys(): + cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim + if name.startswith("mid_block"): + hidden_size = model.unet.config.block_out_channels[-1] + place_in_unet = "mid" + elif name.startswith("up_blocks"): + block_id = int(name[len("up_blocks.")]) + hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id] + place_in_unet = "up" + elif name.startswith("down_blocks"): + block_id = int(name[len("down_blocks.")]) + hidden_size = model.unet.config.block_out_channels[block_id] + place_in_unet = "down" + else: + continue + + cross_att_count += 1 + # attn_procs[name] = AttendExciteCrossAttnProcessor( + # attnstore=controller, place_in_unet=place_in_unet + # ) + attn_procs[name] = CountingCrossAttnProcessor1( + attnstore=controller, place_in_unet=place_in_unet + ) + + model.unet.set_attn_processor(attn_procs) + controller.num_att_layers = cross_att_count + +def register_hier_output(model): + self = model.unet + logger = logging.get_logger(__name__) # pylint: disable=invalid-name + def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None, + attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None, down_block_additional_residuals=None, + mid_block_additional_residual=None, encoder_attention_mask=None, return_dict=True): + + out_list = [] + + + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be 
forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + if attention_mask is not None: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + if encoder_attention_mask is not None: + encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + timesteps = timestep + if not torch.is_tensor(timesteps): + # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + t_emb = t_emb.to(dtype=sample.dtype) + + emb = self.time_embedding(t_emb, timestep_cond) + aug_emb = None + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + # `Timesteps` does 
not contain any weights and will always return f32 tensors + # there might be better ways to encapsulate this. + class_labels = class_labels.to(dtype=sample.dtype) + + class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype) + + if self.config.class_embeddings_concat: + emb = torch.cat([emb, class_emb], dim=-1) + else: + emb = emb + class_emb + + if self.config.addition_embed_type == "text": + aug_emb = self.add_embedding(encoder_hidden_states) + elif self.config.addition_embed_type == "text_image": + # Kandinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + + image_embs = added_cond_kwargs.get("image_embeds") + text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states) + aug_emb = self.add_embedding(text_embs, image_embs) + elif self.config.addition_embed_type == "text_time": + # SDXL - style + if "text_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`" + ) + text_embeds = added_cond_kwargs.get("text_embeds") + if "time_ids" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`" + ) + time_ids = added_cond_kwargs.get("time_ids") + time_embeds = self.add_time_proj(time_ids.flatten()) + time_embeds = time_embeds.reshape((text_embeds.shape[0], -1)) + + add_embeds = torch.concat([text_embeds, time_embeds], dim=-1) + add_embeds = add_embeds.to(emb.dtype) + aug_emb = self.add_embedding(add_embeds) + elif self.config.addition_embed_type == "image": + # Kandinsky 2.2 - style + if "image_embeds" not in 
added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + aug_emb = self.add_embedding(image_embs) + elif self.config.addition_embed_type == "image_hint": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`" + ) + image_embs = added_cond_kwargs.get("image_embeds") + hint = added_cond_kwargs.get("hint") + aug_emb, hint = self.add_embedding(image_embs, hint) + sample = torch.cat([sample, hint], dim=1) + + emb = emb + aug_emb if aug_emb is not None else emb + + if self.time_embed_act is not None: + emb = self.time_embed_act(emb) + + if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj": + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj": + # Kadinsky 2.1 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`" + ) + + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds) + elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj": + # Kandinsky 2.2 - style + if "image_embeds" not in added_cond_kwargs: + raise ValueError( + f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` 
to be passed in `added_conditions`" + ) + image_embeds = added_cond_kwargs.get("image_embeds") + encoder_hidden_states = self.encoder_hid_proj(image_embeds) + # 2. pre-process + sample = self.conv_in(sample) # 1, 320, 64, 64 + + # 2.5 GLIGEN position net + if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None: + cross_attention_kwargs = cross_attention_kwargs.copy() + gligen_args = cross_attention_kwargs.pop("gligen") + cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)} + + # 3. down + lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0 + + is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None + is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None + + down_block_res_samples = (sample,) + + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + # For t2i-adapter CrossAttnDownBlock2D + additional_residuals = {} + if is_adapter and len(down_block_additional_residuals) > 0: + additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0) + + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + **additional_residuals, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale) + + if is_adapter and len(down_block_additional_residuals) > 0: + sample += down_block_additional_residuals.pop(0) + + down_block_res_samples += res_samples + + if is_controlnet: + new_down_block_res_samples = () + + for down_block_res_sample, down_block_additional_residual in zip( + down_block_res_samples, 
down_block_additional_residuals + ): + down_block_res_sample = down_block_res_sample + down_block_additional_residual + new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,) + + down_block_res_samples = new_down_block_res_samples + + # 4. mid + if self.mid_block is not None: + sample = self.mid_block( + sample, + emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + cross_attention_kwargs=cross_attention_kwargs, + encoder_attention_mask=encoder_attention_mask, + ) + # To support T2I-Adapter-XL + if ( + is_adapter + and len(down_block_additional_residuals) > 0 + and sample.shape == down_block_additional_residuals[0].shape + ): + sample += down_block_additional_residuals.pop(0) + + if is_controlnet: + sample = sample + mid_block_additional_residual + + # 5. up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + cross_attention_kwargs=cross_attention_kwargs, + upsample_size=upsample_size, + attention_mask=attention_mask, + encoder_attention_mask=encoder_attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + upsample_size=upsample_size, + scale=lora_scale, + ) + + out_list.append(sample) + + # 6. 
class AttentionControl(abc.ABC):
    """Base hook invoked once per attention layer during diffusion sampling.

    Tracks the current step / layer index and dispatches every attention map
    to `forward`, which subclasses implement to inspect or modify the maps.
    """

    def step_callback(self, x_t):
        # Hook for editing the latent between diffusion steps; identity here.
        return x_t

    def between_steps(self):
        # Hook fired after all attention layers of one step have been seen.
        return

    @property
    def num_uncond_att_layers(self):
        # Number of leading (unconditional-pass) attention layers to skip.
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            # self.forward(attn, is_cross, place_in_unet)
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # Batch is [uncond, cond]; only the conditional half is processed.
                half = attn.shape[0] // 2
                attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            # One full pass through the UNet finished: advance the step counter.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0

    def __init__(self):
        self.cur_step = 0
        # Set externally once the attention processors have been registered.
        self.num_att_layers = -1
        self.cur_att_layer = 0


class EmptyControl(AttentionControl):
    """No-op controller: every attention map passes through unchanged."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        return attn


class AttentionStore(AttentionControl):
    """Collects attention maps per UNet location (down/mid/up, cross/self)."""

    @staticmethod
    def get_empty_store():
        return {
            "down_cross": [], "mid_cross": [], "up_cross": [],
            "down_self": [], "mid_self": [], "up_self": [],
        }

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        kind = "cross" if is_cross else "self"
        # Skip very large maps to bound memory usage.
        if attn.shape[1] <= self.max_size ** 2:
            self.step_store[f"{place_in_unet}_{kind}"].append(attn)
        return attn

    def between_steps(self):
        # Publish this step's maps; optionally accumulate them across steps.
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if not self.global_store:
                    self.global_store = self.step_store
                else:
                    for key, maps in self.global_store.items():
                        for i in range(len(maps)):
                            maps[i] += self.step_store[key][i].detach()
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        return self.attention_store

    def get_average_global_attention(self):
        # Average of the globally accumulated maps over the steps seen so far.
        return {
            key: [item / self.cur_step for item in self.global_store[key]]
            for key in self.attention_store
        }

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, max_size=32, save_global_store=False):
        """Initialize an empty AttentionStore.

        :param max_size: maps with more than max_size**2 tokens are not stored
        :param save_global_store: also accumulate maps across diffusion steps
        """
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.max_size = max_size
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}


def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    """Average all stored attention maps of spatial resolution `res`.

    Gathers maps from the requested UNet locations for prompt `select` and
    returns their mean, shaped [res, res, n_tokens].
    """
    kind = "cross" if is_cross else "self"
    num_pixels = res ** 2
    stored = attention_store.get_average_attention()
    collected = []
    for location in from_where:
        for item in stored[f"{location}_{kind}"]:
            if item.shape[1] == num_pixels:
                per_prompt = item.reshape(len(prompts), -1, res, res, item.shape[-1])
                collected.append(per_prompt[select])
    stacked = torch.cat(collected, dim=0)
    return stacked.sum(0) / stacked.shape[0]
def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    """Display one attention heatmap per prompt token.

    :param tokenizer: tokenizer providing `encode` / `decode`
    :param prompts: list of prompts whose maps were recorded
    :param attention_store: store filled during sampling
    :param res: spatial resolution of the maps to aggregate
    :param from_where: UNet locations to gather from ("down", "mid", "up")
    :param select: index of the prompt to visualize
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    # BUG FIX: `prompts` was missing from this call, so `attention_store` was
    # passed as `prompts` and every other argument shifted by one position.
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))


def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0, prompts=("",)):
    """Display the top SVD components of the aggregated self-attention.

    :param max_com: number of singular components to visualize
    :param prompts: prompts recorded in the store; defaults to a single dummy
        prompt (added as a trailing keyword to stay backward-compatible —
        the previous call crashed because `aggregate_attention` requires it)
    """
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
    # Center the rows before the SVD so components capture variation, not the mean.
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))
def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tile images (list or [N, H, W, 3] array) into a grid and display it.

    :param num_rows: number of grid rows
    :param offset_ratio: gap between tiles, as a fraction of tile height
    """
    if type(images) is list:
        # BUG FIX: pad up to the next multiple of num_rows. The previous
        # `len(images) % num_rows` padded by the remainder itself, which for
        # num_rows >= 3 left num_items non-divisible and silently dropped
        # trailing images from the grid. Unchanged for num_rows == 1.
        num_empty = (num_rows - len(images) % num_rows) % num_rows
    elif images.ndim == 4:
        num_empty = (num_rows - images.shape[0] % num_rows) % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    display(pil_img)

def self_cross_attn(self_attn, cross_attn):
    """Modulate self-attention by a cross-attention map.

    :param self_attn: tensor of shape [res, res, res*res]
    :param cross_attn: tensor that squeezes to [res, res]
    :return: tensor of shape [1, 1, res, res]
    """
    cross_attn = cross_attn.squeeze()
    res = self_attn.shape[0]
    assert res == cross_attn.shape[-1]
    # Flatten cross attn [res, res] -> [res*res] so it broadcasts over the
    # last (pixel) axis of the self-attention tensor.
    cross_attn_ = cross_attn.reshape([res*res])
    self_cross_attn = cross_attn_ * self_attn
    self_cross_attn = self_cross_attn.mean(-1).unsqueeze(0).unsqueeze(0)
    return self_cross_attn
def load_stable_diffusion_model(config: RunConfig):
    """Instantiate the Stable Diffusion pipeline selected by `config`, on CPU.

    :param config: run configuration; `config.sd_2_1` chooses SD 2.1-base,
        otherwise SD 1.4 is loaded.
    :return: a `StableDiffusionPipeline` moved to the CPU device
    """
    model_id = (
        "stabilityai/stable-diffusion-2-1-base"
        if config.sd_2_1
        else "CompVis/stable-diffusion-v1-4"
    )
    # stable = StableCountingPipeline.from_pretrained(model_id).to(device)
    pipeline = StableDiffusionPipeline.from_pretrained(model_id)
    return pipeline.to(torch.device('cpu'))
def _load_tiffs(folder: Path, dtype=None):
    """Load a sequence of tiff files from a folder into a 3D numpy array.

    Frames with a channel axis are reduced to grayscale with ITU-R BT.601
    luma weights, so the result is always stacked as [T, H, W] (or the raw
    frame shape when no conversion is needed).

    :param folder: directory containing `*.tif` files
    :param dtype: numpy dtype the frames are cast to
    :raises FileNotFoundError: when the folder contains no tiff files
    """
    files = sorted(folder.glob("*.tif"))
    # BUG FIX: previously an *unsorted* glob() result was probed (and an empty
    # folder raised a bare IndexError). Probe the first frame in sorted order —
    # the same order the loading loop uses — to decide on grayscale conversion.
    if not files:
        raise FileNotFoundError(f"No .tif files found in {folder}")
    turn_gray = tifffile.imread(files[0]).ndim == 3
    end_frame = len(files)

    frames = []
    for f in tqdm(files, leave=False, desc=f"Loading [0:{end_frame}]"):
        img = tifffile.imread(f).astype(dtype)
        if turn_gray and img.ndim == 3:
            # Drop alpha / extra channels, then convert RGB to luma.
            if img.shape[-1] > 3:
                img = img[..., :3]
            img = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2])
        frames.append(img)
    return np.stack(frames)
def basicConfig(*args, **kwargs):
    """No-op stand-in for `logging.basicConfig` (see monkey-patch below)."""
    return


# To prevent duplicate logs, we mask this baseConfig setting: third-party
# code calling logging.basicConfig() would otherwise attach an extra root
# handler and every record emitted through our loggers would appear twice.
logging.basicConfig = basicConfig
def iou_torch(inst1, inst2):
    """Intersection-over-union of two boolean masks; NaN when both are empty."""
    overlap = torch.logical_and(inst1, inst2).sum().float()
    combined = torch.logical_or(inst1, inst2).sum().float()
    return overlap / combined if combined != 0 else torch.tensor(float('nan'))

def get_instances_torch(mask):
    """Return one boolean mask per non-background (non-zero) instance label."""
    return [mask == label for label in torch.unique(mask) if label != 0]

def compute_instance_miou(pred_mask, gt_mask):
    """Mean best-IoU over GT instances.

    Both arguments are integer [H, W] instance label maps; label 0 is
    background. Returns NaN when the ground truth contains no instances.
    """
    pred_instances = get_instances_torch(pred_mask)
    gt_instances = get_instances_torch(gt_mask)

    best_ious = []
    for gt in gt_instances:
        best = torch.tensor(0.0).to(pred_mask.device)
        for pred in pred_instances:
            score = iou_torch(pred, gt)
            if score > best:
                best = score
        best_ious.append(best)

    # No GT instances: the metric is undefined.
    if not best_ious:
        return torch.tensor(float('nan'))
    return torch.nanmean(torch.stack(best_ious))

from torch import Tensor


def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Average Dice coefficient over all batch elements (or a single mask)."""
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    if input.dim() == 2 or not reduce_batch_first:
        sum_dim = (-1, -2)
    else:
        sum_dim = (-1, -2, -3)

    inter = 2 * (input * target).sum(dim=sum_dim)
    sets_sum = input.sum(dim=sum_dim) + target.sum(dim=sum_dim)
    # Empty prediction AND empty target count as a perfect match.
    sets_sum = torch.where(sets_sum == 0, inter, sets_sum)

    return ((inter + epsilon) / (sets_sum + epsilon)).mean()
def parse_train_args():
    """Parse training arguments from the CLI and an optional YAML config file.

    Uses configargparse, so every flag below can also be supplied via the
    YAML file given with -c/--config (CLI values override the file).
    Returns the parsed argparse Namespace; unknown arguments are tolerated
    (parse_known_args) and silently discarded.

    NOTE(review): flags declared with `type=bool` (--timestamp, --mixedp,
    --preallocate, --compress, --weight_by_ndivs, --weight_by_dataset) suffer
    the classic argparse pitfall — bool("False") is True — so they can only
    be reliably toggled via the YAML config, not on the command line. Left
    unchanged here to preserve behavior.
    """
    parser = configargparse.ArgumentParser(
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        allow_abbrev=False,
    )
    # --- run setup ---
    parser.add_argument(
        "-c",
        "--config",
        default="_utils/example_config.yaml",
        is_config_file=True,
        help="config file path",
    )
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("-o", "--outdir", type=str, default="runs")
    parser.add_argument("--name", type=str, help="Name to append to timestamp")
    parser.add_argument("--timestamp", type=bool, default=True)
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="",
        help="load this model at start (e.g. to continue training)",
    )
    # --- data / model geometry ---
    parser.add_argument(
        "--ndim", type=int, default=2, help="number of spatial dimensions"
    )
    parser.add_argument("-d", "--d_model", type=int, default=256)
    parser.add_argument("-w", "--window", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--warmup_epochs", type=int, default=10)
    parser.add_argument(
        "--detection_folders",
        type=str,
        nargs="+",
        default=["TRA"],
        help=(
            "Subfolders to search for detections. Defaults to `TRA`, which corresponds"
            " to using only the GT."
        ),
    )
    parser.add_argument("--downscale_temporal", type=int, default=1)
    parser.add_argument("--downscale_spatial", type=int, default=1)
    parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
    parser.add_argument("--from_subfolder", action="store_true")
    # parser.add_argument("--train_samples", type=int, default=50000)
    # --- transformer architecture ---
    parser.add_argument("--num_encoder_layers", type=int, default=6)
    parser.add_argument("--num_decoder_layers", type=int, default=6)
    parser.add_argument("--pos_embed_per_dim", type=int, default=32)
    parser.add_argument("--feat_embed_per_dim", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.00)
    # --- training loop ---
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_tokens", type=int, default=None)
    parser.add_argument("--delta_cutoff", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument(
        "--attn_positional_bias",
        type=str,
        choices=["rope", "bias", "none"],
        default="rope",
    )
    parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
    parser.add_argument("--attn_dist_mode", default="v0")
    parser.add_argument("--mixedp", type=bool, default=True)
    parser.add_argument("--dry", action="store_true")
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--features",
        type=str,
        choices=[
            "none",
            "regionprops",
            "regionprops2",
            "patch",
            "patch_regionprops",
            "wrfeat",
        ],
        default="wrfeat",
    )
    parser.add_argument(
        "--causal_norm",
        type=str,
        choices=["none", "linear", "softmax", "quiet_softmax"],
        default="quiet_softmax",
    )
    parser.add_argument("--div_upweight", type=float, default=2)

    parser.add_argument("--augment", type=int, default=3)
    parser.add_argument("--tracking_frequency", type=int, default=-1)

    parser.add_argument("--sanity_dist", action="store_true")
    parser.add_argument("--preallocate", type=bool, default=False)
    parser.add_argument("--only_prechecks", action="store_true")
    parser.add_argument(
        "--compress", type=bool, default=True, help="compress dataset"
    )

    # --- logging / reproducibility ---
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb", "none"],
    )
    parser.add_argument("--wandb_project", type=str, default="trackastra")
    parser.add_argument(
        "--crop_size",
        type=int,
        # required=True,
        nargs="+",
        default=None,
        help="random crop size for augmentation",
    )
    parser.add_argument(
        "--weight_by_ndivs",
        type=bool,
        default=True,
        help="Oversample windows that contain divisions",
    )
    parser.add_argument(
        "--weight_by_dataset",
        type=bool,
        default=False,
        help=(
            "Inversely weight datasets by number of samples (to counter dataset size"
            " imbalance)"
        ),
    )

    args, unknown_args = parser.parse_known_args()

    # # Hack to allow for --input_test
    # allowed_unknown = ["input_test"]
    # if not set(a.split("=")[0].strip("-") for a in unknown_args).issubset(
    #     set(allowed_unknown)
    # ):
    #     raise ValueError(f"Unknown args: {unknown_args}")

    # pprint(vars(args))

    # for backward compatibility
    # if args.attn_positional_bias == "True":
    #     args.attn_positional_bias = "bias"
    # elif args.attn_positional_bias == "False":
    #     args.attn_positional_bias = False

    # if args.train_samples == 0:
    #     raise NotImplementedError(
    #         "--train_samples must be > 0, full dataset pass not supported."
    #     )

    return args
def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
    """Persist one feedback record to the Hugging Face dataset repo.

    Falls back to local JSON storage (`save_feedback`) when no HF_TOKEN is
    configured or when the upload raises.

    :param query_id: identifier of the query this feedback belongs to
    :param feedback_type: kind of feedback (e.g. like/dislike/correction)
    :param feedback_text: optional free-text comment
    :param img_path: optional path of the image the feedback refers to
    :param bboxes: optional bounding boxes; stringified before upload
    """
    # Without a token we cannot upload; use local storage instead.
    if not HF_TOKEN:
        print("⚠️ No HF_TOKEN found, using local storage")
        save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
        return

    feedback_data = {
        "query_id": query_id,
        "feedback_type": feedback_type,
        "feedback_text": feedback_text,
        "image_path": img_path,
        "bboxes": str(bboxes),  # stringified so the record stays JSON-serializable
        "datetime": time.strftime("%Y-%m-%d %H:%M:%S"),
        "timestamp": time.time()
    }

    try:
        api = HfApi()

        # Unique per-record filename (query id + epoch seconds).
        filename = f"feedback_{query_id}_{int(time.time())}.json"

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(feedback_data, f, indent=2, ensure_ascii=False)

        # BUG FIX: path_in_repo was the placeholder-less literal
        # f"data/(unknown)", so every upload overwrote the same repo file.
        # Use the unique filename constructed above.
        api.upload_file(
            path_or_fileobj=filename,
            path_in_repo=f"data/{filename}",
            repo_id=DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN
        )

        # Clean up the local temp file after a successful upload.
        os.remove(filename)

        print(f"✅ Feedback saved to HF Dataset: {DATASET_REPO}")

    except Exception as e:
        print(f"⚠️ Failed to save to HF Dataset: {e}")
        # Fall back to local storage so the feedback is never lost.
        save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
def parse_first_bbox(bboxes):
    """Return (xmin, ymin, xmax, ymax) of the first box, or None.

    Accepts either dicts with x/y/width/height keys or 4-element sequences
    already in corner format.
    """
    if not bboxes:
        return None
    first = bboxes[0]
    if isinstance(first, dict):
        x = float(first.get("x", 0))
        y = float(first.get("y", 0))
        w = float(first.get("width", 0))
        h = float(first.get("height", 0))
        return x, y, x + w, y + h
    if isinstance(first, (list, tuple)) and len(first) >= 4:
        xmin, ymin, xmax, ymax = (float(v) for v in first[:4])
        return xmin, ymin, xmax, ymax
    return None

def parse_bboxes(bboxes):
    """Convert every recognizable box to [xmin, ymin, xmax, ymax].

    Returns None when no boxes were supplied; unrecognized entries are
    silently skipped.
    """
    if not bboxes:
        return None

    parsed = []
    for box in bboxes:
        if isinstance(box, dict):
            x = float(box.get("x", 0))
            y = float(box.get("y", 0))
            w = float(box.get("width", 0))
            h = float(box.get("height", 0))
            parsed.append([x, y, x + w, y + h])
        elif isinstance(box, (list, tuple)) and len(box) >= 4:
            parsed.append([float(v) for v in box[:4]])

    return parsed
# ===== Segmentation =====
# @spaces.GPU
def segment_with_choice(use_box_choice, annot_value):
    """Segmentation entry point: overlay each instance in its own color with
    yellow contours.

    :param use_box_choice: "Yes" to restrict segmentation to user-drawn boxes
    :param annot_value: annotator payload — [image_path, bboxes?]
    :return: (PIL overlay image, path of the raw uint16 instance mask .tif),
        or (None, None) on failure; a solid red image when no instance found.
    """
    if annot_value is None or len(annot_value) < 1:
        print("❌ No annotation input")
        return None, None

    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []

    print(f"🖼️ Image path: {img_path}")
    box_array = None
    if use_box_choice == "Yes" and bboxes:
        # box = parse_first_bbox(bboxes)
        # if box:
        #     xmin, ymin, xmax, ymax = map(int, box)
        #     box_array = [[xmin, ymin, xmax, ymax]]
        #     print(f"📦 Using bounding box: {box_array}")
        box = parse_bboxes(bboxes)
        if box:
            box_array = box
            print(f"📦 Using bounding boxes: {box_array}")


    # Run the segmentation model (module-global SEG_MODEL / SEG_DEVICE).
    try:
        mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
        print("📏 mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
    except Exception as e:
        print(f"❌ Inference failed: {str(e)}")
        return None, None

    # Save the raw instance mask as a uint16 TIF (returned for download).
    temp_mask_file = tempfile.NamedTemporaryFile(delete=False, suffix=".tif")
    mask_img = Image.fromarray(mask.astype(np.uint16))
    mask_img.save(temp_mask_file.name)
    print(f"💾 Original mask saved to: {temp_mask_file.name}")

    # Load the original image for the overlay.
    try:
        img = Image.open(img_path)
        print("📷 Image mode:", img.mode, "size:", img.size)
    except Exception as e:
        print(f"❌ Failed to open image: {e}")
        return None, None

    try:
        # Resize the image to the mask's (H, W) -> PIL wants (W, H), hence [::-1].
        img_rgb = img.convert("RGB").resize(mask.shape[::-1], resample=Image.BILINEAR)
        img_np = np.array(img_rgb, dtype=np.float32)
        # Normalize to [0, 1] if the image came in as 0..255.
        if img_np.max() > 1.5:
            img_np = img_np / 255.0
    except Exception as e:
        print(f"❌ Error in image conversion/resizing: {e}")
        return None, None

    mask_np = np.array(mask)
    inst_mask = mask_np.astype(np.int32)
    unique_ids = np.unique(inst_mask)
    num_instances = len(unique_ids[unique_ids != 0])  # label 0 is background
    print(f"✅ Instance IDs found: {unique_ids}, Total instances: {num_instances}")

    if num_instances == 0:
        print("⚠️ No instance found, returning dummy red image")
        return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None

    # ==== Color overlay (one color per instance) ====
    overlay = img_np.copy()
    alpha = 0.5  # blend weight of the instance color over the image
    # cmap = cm.get_cmap("hsv", num_instances + 1)

    for inst_id in np.unique(inst_mask):
        if inst_id == 0:
            continue
        binary_mask = (inst_mask == inst_id).astype(np.uint8)
        # color = np.array(cmap(inst_id / (num_instances + 1))[:3])  # RGB only, ignore alpha
        # NOTE(review): get_well_spaced_color is not defined in this view —
        # presumably declared later in app.py; confirm it returns RGB in [0, 1].
        color = get_well_spaced_color(inst_id)
        overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color

        # Draw the instance contour on top of the blended fill.
        contours = measure.find_contours(binary_mask, 0.5)
        for contour in contours:
            contour = contour.astype(np.int32)
            # Clamp contour coordinates into image bounds.
            valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
            valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
            overlay[valid_y, valid_x] = [1.0, 1.0, 0.0]  # yellow contour

    overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)

    return Image.fromarray(overlay), temp_mask_file.name
+ + image_path = annot_value[0] + bboxes = annot_value[1] if len(annot_value) > 1 else [] + + print(f"🖼️ Image path: {image_path}") + box_array = None + if use_box_choice == "Yes" and bboxes: + # box = parse_first_bbox(bboxes) + # if box: + # xmin, ymin, xmax, ymax = map(int, box) + # box_array = [[xmin, ymin, xmax, ymax]] + # print(f"📦 Using bounding box: {box_array}") + box = parse_bboxes(bboxes) + if box: + box_array = box + print(f"📦 Using bounding boxes: {box_array}") + + try: + print(f"🔢 Counting - Image: {image_path}") + + result = run_count( + COUNT_MODEL, + image_path, + box=box_array, + device=COUNT_DEVICE, + visualize=True + ) + + if 'error' in result: + return None, f"❌ Counting failed: {result['error']}" + + count = result['count'] + density_map = result['density_map'] + # save density map as temp file + temp_density_file = tempfile.NamedTemporaryFile(delete=False, suffix=".npy") + np.save(temp_density_file.name, density_map) + print(f"💾 Density map saved to {temp_density_file.name}") + + + try: + img = Image.open(image_path) + print("📷 Image mode:", img.mode, "size:", img.size) + except Exception as e: + print(f"❌ Failed to open image: {e}") + return None, None + + try: + img_rgb = img.convert("RGB").resize(density_map.shape[::-1], resample=Image.BILINEAR) + img_np = np.array(img_rgb, dtype=np.float32) + img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8) + if img_np.max() > 1.5: + img_np = img_np / 255.0 + except Exception as e: + print(f"❌ Error in image conversion/resizing: {e}") + return None, None + + + # Normalize density map to [0, 1] + density_normalized = density_map.copy() + if density_normalized.max() > 0: + density_normalized = (density_normalized - density_normalized.min()) / (density_normalized.max() - density_normalized.min()) + + # Apply colormap + cmap = cm.get_cmap("jet") + alpha = 0.3 + density_colored = cmap(density_normalized)[:, :, :3] # RGB only, ignore alpha + + # Create overlay + overlay = img_np.copy() + + 
# Blend only where density is significant (optional: threshold) + threshold = 0.01 # Only overlay where density > 1% of max + significant_mask = density_normalized > threshold + + overlay[significant_mask] = (1 - alpha) * overlay[significant_mask] + alpha * density_colored[significant_mask] + + # Clip and convert to uint8 + overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8) + + + + + + result_text = f"✅ Detected {round(count)} objects" + if use_box_choice == "Yes" and box: + result_text += f"\n📦 Using bounding box: {box_array}" + + + print(f"✅ Counting done - Count: {count:.1f}") + + return Image.fromarray(overlay), temp_density_file.name, result_text + + # return density_path, result_text + + except Exception as e: + print(f"❌ Counting error: {e}") + import traceback + traceback.print_exc() + return None, f"❌ Counting failed: {str(e)}" + +# ===== Tracking Functionality ===== +def find_tif_dir(root_dir): + """Recursively find the first directory containing .tif files""" + for dirpath, _, filenames in os.walk(root_dir): + if '__MACOSX' in dirpath: + continue + if any(f.lower().endswith('.tif') for f in filenames): + return dirpath + return None + +def is_valid_tiff(filepath): + """Check if a file is a valid TIFF image""" + try: + with Image.open(filepath) as img: + img.verify() + return True + except Exception as e: + return False + +def find_valid_tif_dir(root_dir): + """Recursively find the first directory containing valid .tif files""" + for dirpath, dirnames, filenames in os.walk(root_dir): + if '__MACOSX' in dirpath: + continue + + potential_tifs = [ + os.path.join(dirpath, f) + for f in filenames + if f.lower().endswith(('.tif', '.tiff')) and not f.startswith('._') + ] + + if not potential_tifs: + continue + + valid_tifs = [f for f in potential_tifs if is_valid_tiff(f)] + + if valid_tifs: + print(f"✅ Found {len(valid_tifs)} valid TIFF files in: {dirpath}") + return dirpath + + return None + +def create_ctc_results_zip(output_dir): + """ + Create a ZIP 
file with CTC format results + + Parameters: + ----------- + output_dir : str + Directory containing tracking results (res_track.txt, etc.) + + Returns: + -------- + zip_path : str + Path to created ZIP file + """ + # Create temp directory for ZIP + temp_zip_dir = tempfile.mkdtemp() + zip_filename = f"tracking_results_{time.strftime('%Y%m%d_%H%M%S')}.zip" + zip_path = os.path.join(temp_zip_dir, zip_filename) + + print(f"📦 Creating results ZIP: {zip_path}") + + # Create ZIP with all tracking results + with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as zipf: + # Add all files from output directory + for root, dirs, files in os.walk(output_dir): + for file in files: + file_path = os.path.join(root, file) + arcname = os.path.relpath(file_path, output_dir) + zipf.write(file_path, arcname) + print(f" 📄 Added: {arcname}") + + # Add a README with summary + readme_content = f"""Tracking Results Summary + ======================== + + Generated: {time.strftime('%Y-%m-%d %H:%M:%S')} + + Files: + ------ + - res_track.txt: CTC format tracking data + Format: track_id start_frame end_frame parent_id + + - Segmentation masks + + For more information on CTC format: + http://celltrackingchallenge.net/ + """ + zipf.writestr("README.txt", readme_content) + + print(f"✅ ZIP created: {zip_path} ({os.path.getsize(zip_path) / 1024:.1f} KB)") + return zip_path + +# 使用更智能的颜色分配 - 让相邻的ID颜色差异更大 +def get_well_spaced_color(track_id, num_colors=256): + """Generate well-spaced colors, using contrasting colors for adjacent IDs""" + # 使用质数跳跃来分散颜色 + golden_ratio = 0.618033988749895 + hue = (track_id * golden_ratio) % 1.0 + + # 使用高饱和度和明度 + import colorsys + rgb = colorsys.hsv_to_rgb(hue, 0.9, 0.95) + return np.array(rgb) + + +def extract_first_frame(tif_dir): + """ + Extract the first frame from a directory of TIF files + + Returns: + -------- + first_frame_path : str + Path to the first TIF frame + """ + tif_files = natsorted(glob(os.path.join(tif_dir, "*.tif")) + + glob(os.path.join(tif_dir, 
"*.tiff"))) + valid_tif_files = [f for f in tif_files + if not os.path.basename(f).startswith('._') and is_valid_tiff(f)] + + if valid_tif_files: + return valid_tif_files[0] + return None + +def create_tracking_visualization(tif_dir, output_dir, valid_tif_files): + """ + Create an animated GIF/video showing tracked objects with consistent colors + + Parameters: + ----------- + tif_dir : str + Directory containing input TIF frames + output_dir : str + Directory containing tracking results (masks) + valid_tif_files : list + List of valid TIF file paths + + Returns: + -------- + video_path : str + Path to generated visualization (GIF or first frame) + """ + import numpy as np + from matplotlib import colormaps + from skimage import measure + import tifffile + + # Look for tracking mask files in output directory + # Common CTC formats: man_track*.tif, mask*.tif, or numbered masks + mask_files = natsorted(glob(os.path.join(output_dir, "mask*.tif")) + + glob(os.path.join(output_dir, "man_track*.tif")) + + glob(os.path.join(output_dir, "*.tif"))) + + if not mask_files: + print("⚠️ No mask files found in output directory") + # Return first frame as fallback + return valid_tif_files[0] + + print(f"📊 Found {len(mask_files)} mask files") + + # Create color map for consistent track IDs + # Use a colormap with many distinct colors + # try: + # cmap = colormaps.get_cmap("hsv") + # except: + # from matplotlib import cm + # cmap = cm.get_cmap("hsv") + + frames = [] + alpha = 0.3 # Transparency for overlay + + # Process each frame + num_frames = min(len(valid_tif_files), len(mask_files)) + for i in range(num_frames): + try: + # Load original image using tifffile (handles ZSTD compression) + try: + img_np = tifffile.imread(valid_tif_files[i]) + + # Normalize to [0, 1] range based on actual data type and values + if img_np.dtype == np.uint8: + img_np = img_np.astype(np.float32) / 255.0 + elif img_np.dtype == np.uint16: + # Normalize uint16 to [0, 1] using actual min/max + img_min, 
img_max = img_np.min(), img_np.max() + if img_max > img_min: + img_np = (img_np.astype(np.float32) - img_min) / (img_max - img_min) + else: + img_np = img_np.astype(np.float32) / 65535.0 + else: + # For float or other types, normalize based on actual range + img_np = img_np.astype(np.float32) + img_min, img_max = img_np.min(), img_np.max() + if img_max > img_min: + img_np = (img_np - img_min) / (img_max - img_min) + else: + img_np = np.clip(img_np, 0, 1) + + # Convert to RGB if grayscale + if img_np.ndim == 2: + img_np = np.stack([img_np]*3, axis=-1) + img_np = img_np.astype(np.float32) + if img_np.max() > 1.5: + img_np = img_np / 255.0 + except Exception as e: + print(f"⚠️ Error loading image frame {i}: {e}") + # Fallback to PIL + img = Image.open(valid_tif_files[i]).convert("RGB") + img_np = np.array(img, dtype=np.float32) / 255.0 + + # Load tracking mask using tifffile (handles ZSTD compression) + try: + mask = tifffile.imread(mask_files[i]) + except Exception as e: + print(f"⚠️ Error loading mask frame {i}: {e}") + # Fallback to PIL + mask = np.array(Image.open(mask_files[i])) + + # Resize mask to match image if needed + if mask.shape[:2] != img_np.shape[:2]: + from scipy.ndimage import zoom + zoom_factors = [img_np.shape[0] / mask.shape[0], img_np.shape[1] / mask.shape[1]] + mask = zoom(mask, zoom_factors, order=0).astype(mask.dtype) + + # Create overlay + overlay = img_np.copy() + + # Get unique track IDs (excluding background 0) + track_ids = np.unique(mask) + track_ids = track_ids[track_ids != 0] + + # Color each tracked object + for track_id in track_ids: + # Create binary mask for this track + binary_mask = (mask == track_id) + + # Get consistent color for this track ID + # color = np.array(cmap(int(track_id) % 256)[:3]) + color = get_well_spaced_color(int(track_id)) + + # Blend color onto image + overlay[binary_mask] = (1 - alpha) * overlay[binary_mask] + alpha * color + + # Draw contours (optional, adds yellow boundaries) + try: + contours = 
measure.find_contours(binary_mask.astype(np.uint8), 0.5) + for contour in contours: + contour = contour.astype(np.int32) + valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1) + valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1) + overlay[valid_y, valid_x] = [1.0, 1.0, 0.0] # Yellow contour + except: + pass # Skip contours if they fail + + # Convert to uint8 + overlay_uint8 = np.clip(overlay * 255.0, 0, 255).astype(np.uint8) + frames.append(Image.fromarray(overlay_uint8)) + + if i % 10 == 0 or i == num_frames - 1: + print(f" 📸 Processed frame {i+1}/{num_frames}") + + except Exception as e: + print(f"⚠️ Error processing frame {i}: {e}") + import traceback + traceback.print_exc() + continue + + if not frames: + print("⚠️ No frames were processed successfully") + return valid_tif_files[0] + + # Save as animated GIF + try: + temp_gif = tempfile.NamedTemporaryFile(delete=False, suffix=".gif") + frames[0].save( + temp_gif.name, + save_all=True, + append_images=frames[1:], + duration=200, # 200ms per frame = 5fps + loop=0 + ) + temp_gif.close() # Close the file handle + print(f"✅ Created tracking visualization GIF: {temp_gif.name}") + print(f" Size: {os.path.getsize(temp_gif.name)} bytes, Frames: {len(frames)}") + return temp_gif.name + except Exception as e: + print(f"⚠️ Failed to create GIF: {e}") + import traceback + traceback.print_exc() + # Return first frame as static image fallback + try: + temp_img = tempfile.NamedTemporaryFile(delete=False, suffix=".png") + frames[0].save(temp_img.name) + temp_img.close() + return temp_img.name + except: + return valid_tif_files[0] + +# @spaces.GPU +def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj): + """ + 支持 ZIP 压缩包上传的 Tracking 处理函数 - 支持首帧边界框 + + Parameters: + ----------- + use_box_choice : str + "Yes" or "No" - 是否使用边界框 + first_frame_annot : tuple or None + (image_path, bboxes) from BBoxAnnotator, only used if user annotated first frame + zip_file_obj : File + Uploaded ZIP file containing TIF 
sequence + """ + if zip_file_obj is None: + return None, "⚠️ 请上传包含视频帧的压缩包 (.zip)", None, None + + temp_dir = None + output_temp_dir = None + + try: + # Parse bounding box if provided + box_array = None + if use_box_choice == "Yes" and first_frame_annot is not None: + if isinstance(first_frame_annot, (list, tuple)) and len(first_frame_annot) > 1: + bboxes = first_frame_annot[1] + if bboxes: + # box = parse_first_bbox(bboxes) + # if box: + # xmin, ymin, xmax, ymax = map(int, box) + # box_array = [[xmin, ymin, xmax, ymax]] + # print(f"📦 Using bounding box: {box_array}") + box = parse_bboxes(bboxes) + if box: + box_array = box + print(f"📦 Using bounding boxes: {box_array}") + + # Extract input ZIP + temp_dir = tempfile.mkdtemp() + print(f"\n📦 Extracting to temporary directory: {temp_dir}") + + with zipfile.ZipFile(zip_file_obj.name, 'r') as zip_ref: + extracted_count = 0 + skipped_count = 0 + + for member in zip_ref.namelist(): + basename = os.path.basename(member) + + if ('__MACOSX' in member or + basename.startswith('._') or + basename.startswith('.DS_Store') or + member.endswith('/')): + skipped_count += 1 + continue + + try: + zip_ref.extract(member, temp_dir) + extracted_count += 1 + if basename.lower().endswith(('.tif', '.tiff')): + print(f"📄 Extracted TIFF: {basename}") + except Exception as e: + print(f"⚠️ Failed to extract {member}: {e}") + + print(f"\n📊 Extracted: {extracted_count} files, Skipped: {skipped_count} files") + + # Find valid TIFF directory + tif_dir = find_valid_tif_dir(temp_dir) + + if tif_dir is None: + return None, "❌ Did not find valid TIF directory", None, None + + # Validate TIFF files + tif_files = natsorted(glob(os.path.join(tif_dir, "*.tif")) + + glob(os.path.join(tif_dir, "*.tiff"))) + valid_tif_files = [f for f in tif_files + if not os.path.basename(f).startswith('._') and is_valid_tiff(f)] + + if len(valid_tif_files) == 0: + return None, "❌ Did not find valid TIF files", None, None + + print(f"📈 Using {len(valid_tif_files)} TIF 
files") + + # Store paths for later visualization + first_frame_path = valid_tif_files[0] + + # Create temporary output directory for CTC results + output_temp_dir = tempfile.mkdtemp() + print(f"💾 CTC-format results will be saved to: {output_temp_dir}") + + # Run tracking with optional bounding box + result = run_track( + TRACK_MODEL, + video_dir=tif_dir, + box=box_array, # Pass bounding box if specified + device=TRACK_DEVICE, + output_dir=output_temp_dir + ) + + if 'error' in result: + return None, f"❌ Tracking failed: {result['error']}", None, None + + # Create visualization video of tracked objects + print("\n🎬 Creating tracking visualization...") + try: + tracking_video = create_tracking_visualization( + tif_dir, + output_temp_dir, + valid_tif_files + ) + except Exception as e: + print(f"⚠️ Failed to create visualization: {e}") + import traceback + traceback.print_exc() + # Fallback to first frame if visualization fails + try: + tracking_video = Image.open(first_frame_path) + except: + tracking_video = None + + # Create downloadable ZIP with results + try: + results_zip = create_ctc_results_zip(output_temp_dir) + except Exception as e: + print(f"⚠️ Failed to create ZIP: {e}") + results_zip = None + + bbox_info = "" + if box_array: + bbox_info = f"\n🔲 Using bounding box: [{box_array[0][0]}, {box_array[0][1]}, {box_array[0][2]}, {box_array[0][3]}]" + + result_text = f"""✅ Tracking completed! 
+ + 🖼️ Processed frames: {len(valid_tif_files)}{bbox_info} + + 📥 Click the button below to download CTC-format results + The results include: + - res_track.txt (CTC-format tracking data) + - Other tracking-related files + - README.txt (Results description) + """ + + if use_box_choice == "Yes" and box: + result_text += f"\n📦 Using bounding box: {box_array}" + + print(f"\n✅ Tracking completed") + + # Clean up input temp directory (keep output temp for download) + if temp_dir: + try: + shutil.rmtree(temp_dir) + print(f"🗑️ Cleared input temp directory") + except: + pass + + return results_zip, result_text, gr.update(visible=True), tracking_video + + except zipfile.BadZipFile: + return None, "❌ Not a valid ZIP file", None, None + except Exception as e: + import traceback + traceback.print_exc() + + # Clean up on error + for d in [temp_dir, output_temp_dir]: + if d: + try: + shutil.rmtree(d) + except: + pass + return None, f"❌ Tracking failed: {str(e)}", None, None + + + +# ===== 示例图像 ===== +example_images_seg = [f for f in glob("example_imgs/seg/*")] +# ["example_imgs/seg/003_img.png", "example_imgs/seg/1977_Well_F-5_Field_1.png"] +example_images_cnt = [f for f in glob("example_imgs/cnt/*")] +example_tracking_zips = [f for f in glob("example_imgs/tra/*.zip")] + +# ===== Gradio UI ===== +with gr.Blocks( + title="Microscopy Analysis Suite", + theme=gr.themes.Soft(), + css=""" + .tabs button { + font-size: 18px !important; + font-weight: 600 !important; + padding: 12px 20px !important; + } + .uniform-height { + height: 500px !important; + display: flex !important; + align-items: center !important; + justify-content: center !important; + } + + .uniform-height img, + .uniform-height canvas { + max-height: 500px !important; + object-fit: contain !important; + } + + /* 强制密度图容器和图片高度 */ + #density_map_output { + height: 500px !important; + } + + #density_map_output .image-container { + height: 500px !important; + } + + #density_map_output img { + height: 480px !important; + 
width: auto !important; + max-width: 90% !important; + object-fit: contain !important; + } + """ + ) as demo: + gr.Markdown( + """ + # 🔬 Microscopy Image Analysis Suite + + Supporting three key tasks: + - 🎨 **Segmentation**: Instance segmentation of microscopic objects + - 🔢 **Counting**: Counting microscopic objects based on density maps + - 🎬 **Tracking**: Tracking microscopic objects in video sequences + """ + ) + + # 全局状态 + current_query_id = gr.State(str(uuid.uuid4())) + user_uploaded_examples = gr.State(example_images_seg.copy()) # 初始化时包含原始示例 + + with gr.Tabs(): + # ===== Tab 1: Segmentation ===== + with gr.Tab("🎨 Segmentation"): + gr.Markdown("## Instance Segmentation of Microscopic Objects") + gr.Markdown( + """ + **Instructions:** + 1. Upload an image or select an example image (supports various formats: .png, .jpg, .tif) + 2. (Optional) Specify a target object with a bounding box and select "Yes", or click "Run Segmentation" directly + 3. Click "Run Segmentation" + 4. View the segmentation results, download the original predicted mask (.tif format); if needed, click "Clear Selection" to choose a new image + + 🤘 Rate and submit feedback to help us improve the model! + """ + ) + + with gr.Row(): + with gr.Column(scale=1): + annotator = BBoxAnnotator( + label="🖼️ Upload Image (Optional: Provide a Bounding Box)", + categories=["cell"], + ) + + # Example Images Gallery + example_gallery = gr.Gallery( + label="📁 Example Image Gallery", + columns=len(example_images_seg), + rows=1, + height=120, + object_fit="cover", + show_download_button=False + ) + + + with gr.Row(): + use_box_radio = gr.Radio( + choices=["Yes", "No"], + value="No", + label="🔲 Specify Bounding Box?" 
+ ) + with gr.Row(): + run_seg_btn = gr.Button("▶️ Run Segmentation", variant="primary", size="lg") + clear_btn = gr.Button("🔄 Clear Selection", variant="secondary") + + # Upload Example Image + image_uploader = gr.Image( + label="➕ Upload New Example Image to Gallery", + type="filepath" + ) + + + with gr.Column(scale=2): + seg_output = gr.Image( + type="pil", + label="📸 Segmentation Result", + elem_classes="uniform-height" + ) + + # Download Original Prediction + download_mask_btn = gr.File( + label="📥 Download Original Prediction (.tif format)", + visible=True, + height=40, + ) + + # Satisfaction Rating + score_slider = gr.Slider( + minimum=1, + maximum=5, + step=1, + value=5, + label="🌟 Satisfaction Rating (1-5)" + ) + + # Feedback Textbox + feedback_box = gr.Textbox( + placeholder="Please enter your feedback...", + lines=2, + label="💬 Feedback" + ) + + # Submit Button + submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary") + + feedback_status = gr.Textbox( + label="✅ Submission Status", + lines=1, + visible=False + ) + + # 绑定事件: 运行分割 + run_seg_btn.click( + fn=segment_with_choice, + inputs=[use_box_radio, annotator], + outputs=[seg_output, download_mask_btn] + ) + + # 清空按钮事件 + clear_btn.click( + fn=lambda: None, + inputs=None, + outputs=annotator + ) + + # 初始化Gallery显示 + demo.load( + fn=lambda: example_images_seg.copy(), + outputs=example_gallery + ) + + # 绑定事件: 上传示例图片 + def add_to_gallery(img_path, current_imgs): + if not img_path: + return current_imgs + try: + if img_path not in current_imgs: + current_imgs.append(img_path) + return current_imgs + except: + return current_imgs + + image_uploader.change( + fn=add_to_gallery, + inputs=[image_uploader, user_uploaded_examples], + outputs=user_uploaded_examples + ).then( + fn=lambda imgs: imgs, + inputs=user_uploaded_examples, + outputs=example_gallery + ) + + # 绑定事件: 点击Gallery加载 + def load_from_gallery(evt: gr.SelectData, all_imgs): + if evt.index is not None and evt.index < len(all_imgs): + 
return all_imgs[evt.index] + return None + + example_gallery.select( + fn=load_from_gallery, + inputs=user_uploaded_examples, + outputs=annotator + ) + + # 绑定事件: 提交反馈 + def submit_user_feedback(query_id, score, comment, annot_val): + try: + img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None + bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else [] + + # save_feedback( + # query_id=query_id, + # feedback_type=f"score_{int(score)}", + # feedback_text=comment, + # img_path=img_path, + # bboxes=bboxes + # ) + # 使用 HF 存储 + save_feedback_to_hf( + query_id=query_id, + feedback_type=f"score_{int(score)}", + feedback_text=comment, + img_path=img_path, + bboxes=bboxes + ) + return "✅ Feedback submitted, thank you!", gr.update(visible=True) + except Exception as e: + return f"❌ Submission failed: {str(e)}", gr.update(visible=True) + + submit_feedback_btn.click( + fn=submit_user_feedback, + inputs=[current_query_id, score_slider, feedback_box, annotator], + outputs=[feedback_status, feedback_status] + ) + + # ===== Tab 2: Counting ===== + with gr.Tab("🔢 Counting"): + gr.Markdown("## Microscopy Object Counting Analysis") + gr.Markdown( + """ + **Usage Instructions:** + 1. Upload an image or select an example image (supports multiple formats: .png, .jpg, .tif) + 2. (Optional) Specify a target object with a bounding box and select "Yes", or click "Run Counting" directly + 3. Click "Run Counting" + 4. View the density map, download the original prediction (.npy format); if needed, click "Clear Selection" to choose a new image to run + + 🤘 Rate and submit feedback to help us improve the model! 
+ """ + ) + + with gr.Row(): + with gr.Column(scale=1): + count_annotator = BBoxAnnotator( + label="🖼️ Upload Image (Optional: Provide a Bounding Box)", + categories=["cell"], + ) + + # Example gallery with "add" functionality + with gr.Row(): + count_example_gallery = gr.Gallery( + label="📁 Example Image Gallery", + columns=len(example_images_cnt), + rows=1, + object_fit="cover", + height=120, + value=example_images_cnt.copy(), # Initialize with examples + show_download_button=False + ) + + + with gr.Row(): + count_use_box_radio = gr.Radio( + choices=["Yes", "No"], + value="No", + label="🔲 Specify Bounding Box?" + ) + + with gr.Row(): + count_btn = gr.Button("▶️ Run Counting", variant="primary", size="lg") + clear_btn = gr.Button("🔄 Clear Selection", variant="secondary") + + # Add button to upload new examples + with gr.Row(): + count_image_uploader = gr.File( + label="➕ Add Example Image to Gallery", + file_types=["image"], + type="filepath" + ) + + + with gr.Column(scale=2): + count_output = gr.Image( + label="📸 Density Map", + type="filepath", + elem_id="density_map_output" + + ) + count_status = gr.Textbox( + label="📊 Statistics", + lines=2 + ) + download_density_btn = gr.File( + label="📥 Download Original Prediction (.npy format)", + visible=True + ) + + # Satisfaction rating + score_slider = gr.Slider( + minimum=1, + maximum=5, + step=1, + value=5, + label="🌟 Satisfaction Rating (1-5)" + ) + + # Feedback textbox + feedback_box = gr.Textbox( + placeholder="Please enter your feedback...", + lines=2, + label="💬 Feedback" + ) + + # Submit button + submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary") + + feedback_status = gr.Textbox( + label="✅ Submission Status", + lines=1, + visible=False + ) + + # State for managing gallery images + count_user_examples = gr.State(example_images_cnt.copy()) + + # Function to add image to gallery + def add_to_count_gallery(new_img_file, current_imgs): + """Add uploaded image to gallery""" + if new_img_file 
is None: + return current_imgs, current_imgs + + try: + # Add new image path to list + if new_img_file not in current_imgs: + current_imgs.append(new_img_file) + print(f"✅ Added image to gallery: {new_img_file}") + except Exception as e: + print(f"⚠️ Failed to add image: {e}") + + return current_imgs, current_imgs + + # When user uploads a new image file + count_image_uploader.upload( + fn=add_to_count_gallery, + inputs=[count_image_uploader, count_user_examples], + outputs=[count_user_examples, count_example_gallery] + ) + + # When user selects from gallery, load into annotator + def load_from_count_gallery(evt: gr.SelectData, all_imgs): + """Load selected image from gallery into annotator""" + if evt.index is not None and evt.index < len(all_imgs): + selected_img = all_imgs[evt.index] + print(f"📸 Loading image from gallery: {selected_img}") + return selected_img + return None + + count_example_gallery.select( + fn=load_from_count_gallery, + inputs=count_user_examples, + outputs=count_annotator + ) + + # Run counting + count_btn.click( + fn=count_cells_handler, + inputs=[count_use_box_radio, count_annotator], + outputs=[count_output, download_density_btn, count_status] + ) + + # 清空按钮事件 + clear_btn.click( + fn=lambda: None, + inputs=None, + outputs=count_annotator + ) + + # 绑定事件: 提交反馈 + def submit_user_feedback(query_id, score, comment, annot_val): + try: + img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None + bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else [] + + # save_feedback( + # query_id=query_id, + # feedback_type=f"score_{int(score)}", + # feedback_text=comment, + # img_path=img_path, + # bboxes=bboxes + # ) + # 使用 HF 存储 + save_feedback_to_hf( + query_id=query_id, + feedback_type=f"score_{int(score)}", + feedback_text=comment, + img_path=img_path, + bboxes=bboxes + ) + return "✅ Feedback submitted successfully, thank you!", gr.update(visible=True) + except Exception as e: + return f"❌ Submission failed: {str(e)}", 
gr.update(visible=True) + + submit_feedback_btn.click( + fn=submit_user_feedback, + inputs=[current_query_id, score_slider, feedback_box, annotator], + outputs=[feedback_status, feedback_status] + ) + + # ===== Tab 3: Tracking ===== + with gr.Tab("🎬 Tracking"): + gr.Markdown("## Microscopy Object Video Tracking - Supports ZIP Upload") + gr.Markdown( + """ + **Instructions:** + 1. Upload a ZIP file or select from the example library. The ZIP should contain a sequence of TIF images named in chronological order (e.g., t000.tif, t001.tif...) + 2. (Optional) Specify a target object with a bounding box on the first frame and select "Yes", or click "Run Tracking" directly + 3. Click "Run Tracking" + 4. Download the CTC format results; if needed, click "Clear Selection" to choose a new ZIP file to run + + 🤘 Rate and submit feedback to help us improve the model! + + """ + ) + + with gr.Row(): + with gr.Column(scale=1): + track_zip_upload = gr.File( + label="📦 Upload Image Sequence in ZIP File", + file_types=[".zip"] + ) + + # First frame annotation for bounding box + track_first_frame_annotator = BBoxAnnotator( + label="🖼️ (Optional) First Frame Bounding Box Annotation", + categories=["cell"], + visible=False, # Hidden initially + ) + + # Example ZIP gallery + track_example_gallery = gr.Gallery( + label="📁 Example Video Gallery (Click to Select)", + columns=10, + rows=1, + height=120, + object_fit="contain", + show_download_button=False + ) + + with gr.Row(): + track_use_box_radio = gr.Radio( + choices=["Yes", "No"], + value="No", + label="🔲 Specify Bounding Box?" 
+ ) + + with gr.Row(): + track_btn = gr.Button("▶️ Run Tracking", variant="primary", size="lg") + clear_btn = gr.Button("🔄 Clear Selection", variant="secondary") + + # Add to gallery button + track_gallery_upload = gr.File( + label="➕ Add ZIP to Example Gallery", + file_types=[".zip"], + type="filepath" + ) + + with gr.Column(scale=2): + track_first_frame_preview = gr.Image( + label="📸 Tracking Visualization", + type="filepath", + # height=400, + elem_classes="uniform-height", + interactive=False + ) + + track_output = gr.Textbox( + label="📊 Tracking Information", + lines=8, + interactive=False + ) + + track_download = gr.File( + label="📥 Download Tracking Results (CTC Format)", + visible=False + ) + + # Satisfaction rating + score_slider = gr.Slider( + minimum=1, + maximum=5, + step=1, + value=5, + label="🌟 Satisfaction Rating (1-5)" + ) + + # Feedback textbox + feedback_box = gr.Textbox( + placeholder="Please enter your feedback...", + lines=2, + label="💬 Feedback" + ) + + # Submit button + submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary") + + feedback_status = gr.Textbox( + label="✅ Submission Status", + lines=1, + visible=False + ) + + # State for tracking examples + track_user_examples = gr.State(example_tracking_zips.copy()) + + # Function to get preview image from ZIP + def get_zip_preview(zip_path): + """Extract first frame from ZIP for gallery preview""" + try: + temp_dir = tempfile.mkdtemp() + with zipfile.ZipFile(zip_path, 'r') as zip_ref: + for member in zip_ref.namelist(): + basename = os.path.basename(member) + if ('__MACOSX' not in member and + not basename.startswith('._') and + basename.lower().endswith(('.tif', '.tiff', '.png', '.jpg'))): + zip_ref.extract(member, temp_dir) + extracted_path = os.path.join(temp_dir, member) + + # Load and normalize for preview + import tifffile + import numpy as np + + img_np = tifffile.imread(extracted_path) + if img_np.dtype == np.uint16: + img_min, img_max = img_np.min(), img_np.max() + 
if img_max > img_min: + img_np = ((img_np.astype(np.float32) - img_min) / (img_max - img_min) * 255).astype(np.uint8) + + if img_np.ndim == 2: + img_np = np.stack([img_np]*3, axis=-1) + + # Save preview + preview_path = tempfile.NamedTemporaryFile(delete=False, suffix=".png") + Image.fromarray(img_np).save(preview_path.name) + return preview_path.name + except: + pass + return None + + # Initialize gallery with previews + def init_tracking_gallery(): + """Create preview images for ZIP examples""" + previews = [] + for zip_path in example_tracking_zips: + if os.path.exists(zip_path): + preview = get_zip_preview(zip_path) + if preview: + previews.append(preview) + return previews + + # Load gallery on startup + demo.load( + fn=init_tracking_gallery, + outputs=track_example_gallery + ) + + # Add ZIP to gallery + def add_zip_to_gallery(zip_path, current_zips): + if not zip_path: + return current_zips, track_example_gallery + try: + if zip_path not in current_zips: + current_zips.append(zip_path) + print(f"✅ Added ZIP to gallery: {zip_path}") + # Regenerate previews + previews = [] + for zp in current_zips: + preview = get_zip_preview(zp) + if preview: + previews.append(preview) + return current_zips, previews + except Exception as e: + print(f"⚠️ Error: {e}") + return current_zips, [] + + track_gallery_upload.upload( + fn=add_zip_to_gallery, + inputs=[track_gallery_upload, track_user_examples], + outputs=[track_user_examples, track_example_gallery] + ) + + # Select ZIP from gallery + def load_zip_from_gallery(evt: gr.SelectData, all_zips): + if evt.index is not None and evt.index < len(all_zips): + selected_zip = all_zips[evt.index] + print(f"📁 Selected ZIP from gallery: {selected_zip}") + return selected_zip + return None + + track_example_gallery.select( + fn=load_zip_from_gallery, + inputs=track_user_examples, + outputs=track_zip_upload + ) + + # Load first frame when ZIP is uploaded + def load_first_frame_for_annotation(zip_file_obj): + '''Load and normalize first 
frame from ZIP for annotation''' + if zip_file_obj is None: + return None, gr.update(visible=False) + + import tifffile + import numpy as np + + try: + temp_dir = tempfile.mkdtemp() + with zipfile.ZipFile(zip_file_obj.name, 'r') as zip_ref: + for member in zip_ref.namelist(): + basename = os.path.basename(member) + if ('__MACOSX' not in member and + not basename.startswith('._') and + basename.lower().endswith(('.tif', '.tiff'))): + zip_ref.extract(member, temp_dir) + + tif_dir = find_valid_tif_dir(temp_dir) + if tif_dir: + first_frame = extract_first_frame(tif_dir) + if first_frame: + # Load and normalize the first frame + try: + img_np = tifffile.imread(first_frame) + + # Normalize to [0, 255] uint8 range for display + if img_np.dtype == np.uint8: + pass # Already uint8 + elif img_np.dtype == np.uint16: + # Normalize uint16 using actual min/max + img_min, img_max = img_np.min(), img_np.max() + if img_max > img_min: + img_np = ((img_np.astype(np.float32) - img_min) / (img_max - img_min) * 255).astype(np.uint8) + else: + img_np = (img_np.astype(np.float32) / 65535.0 * 255).astype(np.uint8) + else: + # Float or other types + img_np = img_np.astype(np.float32) + img_min, img_max = img_np.min(), img_np.max() + if img_max > img_min: + img_np = ((img_np - img_min) / (img_max - img_min) * 255).astype(np.uint8) + else: + img_np = np.clip(img_np * 255, 0, 255).astype(np.uint8) + + # Convert to RGB if grayscale + if img_np.ndim == 2: + img_np = np.stack([img_np]*3, axis=-1) + elif img_np.ndim == 3 and img_np.shape[2] > 3: + img_np = img_np[:, :, :3] + + # Save normalized image to temp file + temp_img = tempfile.NamedTemporaryFile(delete=False, suffix=".png") + Image.fromarray(img_np).save(temp_img.name) + + print(f"✅ Loaded and normalized first frame: {first_frame}") + print(f" Original dtype: {tifffile.imread(first_frame).dtype}") + print(f" Normalized to uint8 RGB for annotation") + + return temp_img.name, gr.update(visible=True) + except Exception as e: + print(f"⚠️ 
Error normalizing first frame: {e}")
                        import traceback
                        traceback.print_exc()
                        # Fallback to original file
                        return first_frame, gr.update(visible=True)
        except Exception as e:
            print(f"⚠️ Error loading first frame: {e}")
            import traceback
            traceback.print_exc()
            return None, gr.update(visible=False)

    # Load first frame when ZIP is uploaded
    track_zip_upload.change(
        fn=load_first_frame_for_annotation,
        inputs=track_zip_upload,
        # NOTE(review): the same component is listed twice; the handler returns
        # (image_path, visibility_update) — confirm both outputs are really meant
        # to target track_first_frame_annotator.
        outputs=[track_first_frame_annotator, track_first_frame_annotator]
    )

    # Run tracking
    track_btn.click(
        fn=track_video_handler,
        inputs=[track_use_box_radio, track_first_frame_annotator, track_zip_upload],
        # NOTE(review): track_download appears twice in outputs — verify the
        # handler's return arity/ordering against these components.
        outputs=[track_download, track_output, track_download, track_first_frame_preview]
    )

    # Clear-button event
    clear_btn.click(
        fn=lambda: None,
        inputs=None,
        outputs=track_first_frame_annotator
    )

    # Bind event: submit feedback
    def submit_user_feedback(query_id, score, comment, annot_val):
        """Persist user feedback (score, free-text comment and the current
        annotation, i.e. image path + bounding boxes) via HF storage.

        Returns a (status_message, visibility_update) pair for the status box.
        """
        try:
            # annot_val is expected to be (image_path, bboxes) — TODO confirm
            # against the annotator component's value format.
            img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
            bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else []

            # save_feedback(
            #     query_id=query_id,
            #     feedback_type=f"score_{int(score)}",
            #     feedback_text=comment,
            #     img_path=img_path,
            #     bboxes=bboxes
            # )
            # Use HF storage
            save_feedback_to_hf(
                query_id=query_id,
                feedback_type=f"score_{int(score)}",
                feedback_text=comment,
                img_path=img_path,
                bboxes=bboxes
            )
            return "✅ Feedback submitted successfully, thank you!", gr.update(visible=True)
        except Exception as e:
            # Surface the failure in the UI rather than crashing the handler.
            return f"❌ Submission failed: {str(e)}", gr.update(visible=True)

    submit_feedback_btn.click(
        fn=submit_user_feedback,
        inputs=[current_query_id, score_slider, feedback_box, annotator],
        outputs=[feedback_status, feedback_status]
    )

    gr.Markdown(
        """
        ---
        ### 💡 Technical Details

        **MicroscopyMatching** - A general-purpose microscopy image analysis toolkit based on Stable Diffusion
        """
    )

if __name__ == "__main__":
    demo.queue().launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        ssr_mode=False,
        show_error=True,
    )
diff --git a/config.py b/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a11d3a51a8e9e073dbad3b975edb9908aff4fb4
--- /dev/null
+++ b/config.py
@@ -0,0 +1,44 @@
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List


@dataclass
class RunConfig:
    """Configuration for the Stable Diffusion (Attend-and-Excite style) runs.

    Instantiating this dataclass creates `output_path` on disk as a side
    effect (see `__post_init__`).
    """
    # Guiding text prompt
    prompt: str = ""
    # Whether to use Stable Diffusion v2.1
    sd_2_1: bool = False
    # Which token indices to alter with attend-and-excite
    token_indices: List[int] = field(default_factory=lambda: [2,5])
    # Which random seeds to use when generating
    seeds: List[int] = field(default_factory=lambda: [42])
    # Path to save all outputs to
    output_path: Path = Path('./outputs')
    # Number of denoising steps
    n_inference_steps: int = 50
    # Text guidance scale
    guidance_scale: float = 7.5
    # Number of denoising steps to apply attend-and-excite
    max_iter_to_alter: int = 25
    # Resolution of UNet to compute attention maps over
    attention_res: int = 16
    # Whether to run standard SD or attend-and-excite
    run_standard_sd: bool = False
    # Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in
    thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8})
    # Scale factor for updating the denoised latent z_t
    scale_factor: int = 20
    # Start and end values used for scaling the scale factor - decays linearly with the denoising timestep
    scale_range: tuple = field(default_factory=lambda: (1.0, 0.5))
    # Whether to apply the Gaussian smoothing before computing the maximum attention value for each subject token
    smooth_attentions: bool = True
    # Standard deviation for the Gaussian smoothing
    sigma: float = 0.5
    # Kernel size for the Gaussian smoothing
    kernel_size: int = 3
    # Whether to save cross attention maps for the final results
    save_cross_attention_maps: bool = False

    def __post_init__(self):
        # Side effect on construction: ensure the output directory exists.
        self.output_path.mkdir(exist_ok=True, parents=True)
diff --git a/counting.py b/counting.py
new file mode 100644
index 0000000000000000000000000000000000000000..c54e0842aabebd37fdd661e59274bfcc1bc485c0
--- /dev/null
+++ b/counting.py
@@ -0,0 +1,340 @@
# stable diffusion x loca
import os
import pprint
from typing import Any, List, Optional
import argparse
from huggingface_hub import hf_hub_download
import pyrallis
from pytorch_lightning.utilities.types import STEP_OUTPUT
import torch
import os
from PIL import Image
import numpy as np
from config import RunConfig
from _utils import attn_utils_new as attn_utils
from _utils.attn_utils import AttentionStore
from _utils.misc_helper import *
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2
import warnings
from pytorch_lightning.callbacks import ModelCheckpoint
warnings.filterwarnings("ignore", category=UserWarning)
import pytorch_lightning as pl
from _utils.load_models import load_stable_diffusion_model
from models.model import Counting_with_SD_features_loca as Counting
from pytorch_lightning.loggers import WandbLogger
from models.enc_model.loca_args import get_argparser as loca_get_argparser
from models.enc_model.loca import build_model as build_loca_model
import time
import torchvision.transforms as T
import skimage.io as io

SCALE = 1


class CountingModule(pl.LightningModule):
    """Counting model combining LOCA features with Stable Diffusion attention."""

    def __init__(self, use_box=True):
        super().__init__()
        # use_box: whether exemplar bounding boxes are provided at inference.
        self.use_box = use_box
        self.config = RunConfig() # config for stable diffusion
        self.initialize_model()


    def initialize_model(self):
        """Build the LOCA encoder, the counting adapter and the SD backbone."""

        # load loca model
        # NOTE(review): parse_args() here consumes sys.argv — confirm this is
        # safe when the module is imported inside another CLI app.
        loca_args = loca_get_argparser().parse_args()
        self.loca_model = build_loca_model(loca_args)
        # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
        # weights = {k.replace("module","") : v for k, v in weights.items()}
        # self.loca_model.load_state_dict(weights, strict=False)
        # del
weights + + self.counting_adapter = Counting(scale_factor=SCALE) + # if os.path.isfile(self.args.adapter_weight): + # adapter_weight = torch.load(self.args.adapter_weight,map_location=torch.device('cpu')) + # self.counting_adapter.load_state_dict(adapter_weight, strict=False) + + ### load stable diffusion and its controller + self.stable = load_stable_diffusion_model(config=self.config) + self.noise_scheduler = self.stable.scheduler + self.controller = AttentionStore(max_size=64) + attn_utils.register_attention_control(self.stable, self.controller) + attn_utils.register_hier_output(self.stable) + + ##### initialize token_emb ##### + placeholder_token = "" + self.task_token = "repetitive objects" + # Add the placeholder token in tokenizer + num_added_tokens = self.stable.tokenizer.add_tokens(placeholder_token) + if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." 
+ ) + try: + task_embed_from_pretrain = hf_hub_download( + repo_id="phoebe777777/111", + filename="task_embed.pth", + token=None, + force_download=False + ) + placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token) + self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer)) + + token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data + token_embeds[placeholder_token_id] = task_embed_from_pretrain + except: + initializer_token = "count" + token_ids = self.stable.tokenizer.encode(initializer_token, add_special_tokens=False) + # Check if initializer_token is a single token or a sequence of tokens + if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + + initializer_token_id = token_ids[0] + placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token) + + self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer)) + + token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data + token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] + + # others + self.placeholder_token = placeholder_token + self.placeholder_token_id = placeholder_token_id + + + def move_to_device(self, device): + self.stable.to(device) + if self.loca_model is not None and self.counting_adapter is not None: + self.loca_model.to(device) + self.counting_adapter.to(device) + self.to(device) + + def forward(self, data_path, box=None): + filename = data_path.split("/")[-1] + img = Image.open(data_path).convert("RGB") + width, height = img.size + input_image = T.Compose([T.ToTensor(), T.Resize((512, 512))])(img) + input_image_stable = input_image - 0.5 + input_image = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_image) + if box is not None: + boxes = torch.tensor(box) / torch.tensor([width, height, width, height]) * 512 # xyxy, normalized + assert self.use_box == True + else: + boxes = 
torch.tensor([[100,100,130,130], [200,200,250,250]], dtype=torch.float32) # dummy box + assert self.use_box == False + + # move to device + input_image = input_image.unsqueeze(0).to(self.device) + boxes = boxes.unsqueeze(0).to(self.device) + input_image_stable = input_image_stable.unsqueeze(0).to(self.device) + + + + latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach() + latents = latents * 0.18215 + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + timesteps = torch.tensor([20], device=latents.device).long() + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + input_ids_ = self.stable.tokenizer( + self.placeholder_token + " repetitive objects", + # "object", + padding="max_length", + truncation=True, + max_length=self.stable.tokenizer.model_max_length, + return_tensors="pt", + ) + input_ids = input_ids_["input_ids"].to(self.device) + attention_mask = input_ids_["attention_mask"].to(self.device) + encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0] + + input_image = input_image.to(self.device) + boxes = boxes.to(self.device) + + + task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id) + if self.use_box: + loca_out = self.loca_model.forward_before_reg(input_image, boxes) + loca_feature_bf_regression = loca_out["feature_bf_regression"] + adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes) # shape [1, 768] + if task_loc_idx.shape[0] == 0: + encoder_hidden_states[0,2,:] = adapted_emb.squeeze() # 放在task prompt下一位 + else: + encoder_hidden_states[0,task_loc_idx[0, 1]+1,:] = adapted_emb.squeeze() # 放在task prompt下一位 + + # Predict the noise residual + noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states) + noise_pred = noise_pred.sample + attention_store = self.controller.attention_store + + + attention_maps = [] + exemplar_attention_maps = [] + exemplar_attention_maps1 = [] + 
exemplar_attention_maps2 = [] + exemplar_attention_maps3 = [] + + cross_self_task_attn_maps = [] + cross_self_exe_attn_maps = [] + + # only use 64x64 self-attention + self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096] + prompts=[self.config.prompt], # 这里要改么 + attention_store=self.controller, + res=64, + from_where=("up", "down"), + is_cross=False, + select=0 + ) + self_attn_aggregate32 = attn_utils.aggregate_attention( # [res, res, 4096] + prompts=[self.config.prompt], # 这里要改么 + attention_store=self.controller, + res=32, + from_where=("up", "down"), + is_cross=False, + select=0 + ) + self_attn_aggregate16 = attn_utils.aggregate_attention( # [res, res, 4096] + prompts=[self.config.prompt], # 这里要改么 + attention_store=self.controller, + res=16, + from_where=("up", "down"), + is_cross=False, + select=0 + ) + + # cross attention + for res in [32, 16]: + attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77] + prompts=[self.config.prompt], # 这里要改么 + attention_store=self.controller, + res=res, + from_where=("up", "down"), + is_cross=True, + select=0 + ) + + task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res] + attention_maps.append(task_attn_) + if self.use_box: + exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0) # 取exemplar的attn + exemplar_attention_maps.append(exemplar_attns) + else: + exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0) + exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0) + exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0) + exemplar_attention_maps1.append(exemplar_attns1) + exemplar_attention_maps2.append(exemplar_attns2) + exemplar_attention_maps3.append(exemplar_attns3) + + + scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))] + attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))]) + task_attn_64 = 
torch.mean(attns, dim=0, keepdim=True) + cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64) + cross_self_task_attn_maps.append(cross_self_task_attn) + + if self.use_box: + scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))] + attns = torch.cat([F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))]) + exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True) + + cross_self_exe_attn = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64) + cross_self_exe_attn_maps.append(cross_self_exe_attn) + else: + scale_factors = [(64 // exemplar_attention_maps1[i].shape[-1]) for i in range(len(exemplar_attention_maps1))] + attns = torch.cat([F.interpolate(exemplar_attention_maps1[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps1))]) + exemplar_attn_64_1 = torch.mean(attns, dim=0, keepdim=True) + + scale_factors = [(64 // exemplar_attention_maps2[i].shape[-1]) for i in range(len(exemplar_attention_maps2))] + attns = torch.cat([F.interpolate(exemplar_attention_maps2[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps2))]) + exemplar_attn_64_2 = torch.mean(attns, dim=0, keepdim=True) + + scale_factors = [(64 // exemplar_attention_maps3[i].shape[-1]) for i in range(len(exemplar_attention_maps3))] + attns = torch.cat([F.interpolate(exemplar_attention_maps3[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps3))]) + exemplar_attn_64_3 = torch.mean(attns, dim=0, keepdim=True) + + + cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64) + cross_self_task_attn_maps.append(cross_self_task_attn) + + # if self.args.merge_exemplar == "average": + cross_self_exe_attn1 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_1) + 
cross_self_exe_attn2 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_2) + cross_self_exe_attn3 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_3) + exemplar_attn_64 = (exemplar_attn_64_1 + exemplar_attn_64_2 + exemplar_attn_64_3) / 3 + cross_self_exe_attn = (cross_self_exe_attn1 + cross_self_exe_attn2 + cross_self_exe_attn3) / 3 + + exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6) + + attn_stack = [exemplar_attn_64 / 2, cross_self_exe_attn / 2, exemplar_attn_64, cross_self_exe_attn] + attn_stack = torch.cat(attn_stack, dim=1) + + if not self.use_box: + + # cross_self_exe_attn_np = cross_self_exe_attn.detach().squeeze().cpu().numpy() + # boxes = gen_dummy_boxes(cross_self_exe_attn_np, max_boxes=1) + # boxes = boxes.to(self.device) + + loca_out = self.loca_model.forward_before_reg(input_image, boxes) + loca_feature_bf_regression = loca_out["feature_bf_regression"] + attn_out = self.loca_model.forward_reg(loca_out, attn_stack, feature_list[-1]) + pred_density = attn_out["pred"].squeeze().cpu().numpy() + pred_cnt = pred_density.sum().item() + + # resize pred_density to original image size + pred_density_rsz = cv2.resize(pred_density, (width, height), interpolation=cv2.INTER_CUBIC) + pred_density_rsz = pred_density_rsz / pred_density_rsz.sum() * pred_cnt + + return pred_density_rsz, pred_cnt + + +def inference(data_path, box=None, save_path="./example_imgs", visualize=False): + if box is not None: + use_box = True + else: + use_box = False + model = CountingModule(use_box=use_box) + load_msg = model.load_state_dict(torch.load("pretrained/microscopy_matching_cnt.pth"), strict=True) + model.eval() + with torch.no_grad(): + density_map, cnt = model(data_path, box) + + if visualize: + img = io.imread(data_path) + if len(img.shape) == 3 and img.shape[2] > 3: + img = img[:,:,:3] + if len(img.shape) == 2: + img = np.stack([img]*3, axis=-1) + img_show = img.squeeze() + 
        density_map_show = density_map.squeeze()
        os.makedirs(save_path, exist_ok=True)
        filename = data_path.split("/")[-1]
        # Min-max normalize the input image for display.
        # NOTE(review): no epsilon guard — a constant image would divide by zero; confirm inputs.
        img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show))
        fig, ax = plt.subplots(1,2, figsize=(12,6))
        ax[0].imshow(img_show)
        ax[0].axis('off')
        ax[0].set_title(f"Input image")
        ax[1].imshow(img_show)
        ax[1].imshow(density_map_show, cmap='jet', alpha=0.5) # Overlay density map with some transparency
        ax[1].axis('off')
        ax[1].set_title(f"Predicted density map, count: {cnt:.1f}")
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, filename.split(".")[0]+"_cnt.png"), dpi=300)
        plt.close()
    # NOTE(review): only the density map is returned; the count `cnt` is
    # discarded here (callers that need it use the module's forward directly).
    return density_map

def main():
    """Demo entry point: run counting on a bundled example image."""

    inference(
        data_path = "example_imgs/1977_Well_F-5_Field_1.png",
        # box=[[150, 60, 183, 87]],
        save_path = "./example_imgs",
        visualize = True
    )

if __name__ == "__main__":
    main()
diff --git a/example_imgs/cnt/047cell.png b/example_imgs/cnt/047cell.png
new file mode 100644
index 0000000000000000000000000000000000000000..99059f68d94098226ba46ed2b9b1d31097d2dd6a
--- /dev/null
+++ b/example_imgs/cnt/047cell.png
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:3c9fc3d2ab7beecb16d850b1ef82d70a7f7011051d0199f866bc31c42c296d42
size 72806
diff --git a/example_imgs/cnt/62_10.png b/example_imgs/cnt/62_10.png
new file mode 100644
index 0000000000000000000000000000000000000000..6f889109101c69aad1091853eaeeaf3aa9f8110c
--- /dev/null
+++ b/example_imgs/cnt/62_10.png
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:b93c916a81eaec1a3511b9379fa293c026bbe74977bc21fc7666a83c92d3b122
size 91494
diff --git a/example_imgs/cnt/6800-17000_GTEX-XQ3S_Adipose-Subcutaneous.png b/example_imgs/cnt/6800-17000_GTEX-XQ3S_Adipose-Subcutaneous.png
new file mode 100644
index 0000000000000000000000000000000000000000..b6cfb8678479eeee082726443fdc5983c439f14a
--- /dev/null
+++
b/example_imgs/cnt/6800-17000_GTEX-XQ3S_Adipose-Subcutaneous.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:467319789c5b5b6c370a23c126c33044a841a115cf24b79f75106b5521cd5c44 +size 79904 diff --git a/example_imgs/seg/003_img.png b/example_imgs/seg/003_img.png new file mode 100644 index 0000000000000000000000000000000000000000..4839fb37ad72c6aa96391a95371109e2f43c7bef --- /dev/null +++ b/example_imgs/seg/003_img.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:41515cf5d7405135db4656c2cc61b59ab341143bfbee952b44a9542944e8528f +size 302381 diff --git a/example_imgs/seg/1-23 [Scan I08].png b/example_imgs/seg/1-23 [Scan I08].png new file mode 100644 index 0000000000000000000000000000000000000000..ab74165f268f7e6e6df9fbf518c4f7831645e75b --- /dev/null +++ b/example_imgs/seg/1-23 [Scan I08].png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:a96dfccdd794a95c9907b0eedecbd53dee078943d9a3dcdb43e11a36d34f5a1f +size 1418500 diff --git a/example_imgs/seg/10X_B2_Tile-15.aligned.png b/example_imgs/seg/10X_B2_Tile-15.aligned.png new file mode 100644 index 0000000000000000000000000000000000000000..fab2ef5a9546eb5fe43cf652ed7954d7d933eaf6 --- /dev/null +++ b/example_imgs/seg/10X_B2_Tile-15.aligned.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:e8dce16565ccfb055438b0b65d9e70b5be6cc36c61a964eed53d7ec782b5afa3 +size 1519602 diff --git a/example_imgs/seg/1977_Well_F-5_Field_1.png b/example_imgs/seg/1977_Well_F-5_Field_1.png new file mode 100644 index 0000000000000000000000000000000000000000..c99fabcd51cc74e034f1220cd77058257a16a06f --- /dev/null +++ b/example_imgs/seg/1977_Well_F-5_Field_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:145a99e724048ed40db7843e57a1d93cd2e1f6e221d167a29b732740d6302c52 +size 2430598 diff --git a/example_imgs/seg/200972823[5179]_RhoGGG_YAP_TAZ [200972823 Well K6 Field #2].png b/example_imgs/seg/200972823[5179]_RhoGGG_YAP_TAZ 
[200972823 Well K6 Field #2].png new file mode 100644 index 0000000000000000000000000000000000000000..c0eed6705e815a3027d27a6afe0df28c9a86e0f3 --- /dev/null +++ b/example_imgs/seg/200972823[5179]_RhoGGG_YAP_TAZ [200972823 Well K6 Field #2].png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:56bd7a8df07d66ff5f8dac67aa116efe0869f6c46d9ce77e595535a6acd60ae9 +size 1385832 diff --git a/example_imgs/seg/A172_Phase_C7_1_00d00h00m_1.png b/example_imgs/seg/A172_Phase_C7_1_00d00h00m_1.png new file mode 100644 index 0000000000000000000000000000000000000000..9a3ded9365ee1d87e24e6470240a63e3821eb631 --- /dev/null +++ b/example_imgs/seg/A172_Phase_C7_1_00d00h00m_1.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:f57430b87923f5de9a5799cc84016aeb5d99cd5068481a9fedae2a68fa9bba43 +size 158954 diff --git a/example_imgs/seg/JE2NileRed_oilp22_PMP_101220_011_NR.png b/example_imgs/seg/JE2NileRed_oilp22_PMP_101220_011_NR.png new file mode 100644 index 0000000000000000000000000000000000000000..c40913f3d94925cfc6ade6e24318ad382e509a25 --- /dev/null +++ b/example_imgs/seg/JE2NileRed_oilp22_PMP_101220_011_NR.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bdf31a4eab7826435407f2f88bfeee8f95c2b04d8f579cf6281b7f5838195b03 +size 64356 diff --git a/example_imgs/seg/OpenTest_031.png b/example_imgs/seg/OpenTest_031.png new file mode 100644 index 0000000000000000000000000000000000000000..b8d4af73de635beef8619d6322504248c48add5a --- /dev/null +++ b/example_imgs/seg/OpenTest_031.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:973ecd4ca18c650d630491c1f3531ba4ff20c12a37728dc79f279b26651d0c82 +size 966215 diff --git a/example_imgs/seg/X_24.png b/example_imgs/seg/X_24.png new file mode 100644 index 0000000000000000000000000000000000000000..da49ccd489e97b3ab76375b5e9a6326c62724384 --- /dev/null +++ b/example_imgs/seg/X_24.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid 
sha256:514b2df4bdcdd1d09d1f032284a5c2aaa0572d2f1ec148b256e4bbf5d68eb3c7 +size 102387 diff --git a/example_imgs/seg/exp_A01_G002_0001.oir.png b/example_imgs/seg/exp_A01_G002_0001.oir.png new file mode 100644 index 0000000000000000000000000000000000000000..b07e109b6e5a29a1af897460d99d310d9884e27f --- /dev/null +++ b/example_imgs/seg/exp_A01_G002_0001.oir.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:9c22531659320908a688da277b7f67b70aafb450e035f56e3962ebfd3423140f +size 1685570 diff --git a/example_imgs/tra/tracking_test_sequence.zip b/example_imgs/tra/tracking_test_sequence.zip new file mode 100644 index 0000000000000000000000000000000000000000..475cba40ede1349c45cb57d1102be9b0e3241144 --- /dev/null +++ b/example_imgs/tra/tracking_test_sequence.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:bda69434e3de8103c98313777640acd35fc7501eec4b1528456304142b18797f +size 10392163 diff --git a/example_imgs/tra/tracking_test_sequence2.zip b/example_imgs/tra/tracking_test_sequence2.zip new file mode 100644 index 0000000000000000000000000000000000000000..3eee6976130a88dd9f9da8fc597e06b8c2d5b030 --- /dev/null +++ b/example_imgs/tra/tracking_test_sequence2.zip @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:120cc2a75a4dd571b8f8ee7ea363a9b82a2b4c516376ccf4f287b6864d2dd576 +size 2288296 diff --git a/inference_count.py b/inference_count.py new file mode 100644 index 0000000000000000000000000000000000000000..23152d4f7e3b8cf3f624a89268176f89b7de8f3a --- /dev/null +++ b/inference_count.py @@ -0,0 +1,237 @@ +# inference_count.py +# 计数模型推理模块 - 独立版本 + +import torch +import numpy as np +from PIL import Image +import matplotlib.pyplot as plt +import tempfile +import os +from huggingface_hub import hf_hub_download +from counting import CountingModule + +MODEL = None +DEVICE = torch.device("cpu") + +def load_model(use_box=False): + """ + 加载计数模型 + + Args: + use_box: 是否使用边界框 + + Returns: + model: 加载的模型 + device: 设备 + 
""" + global MODEL, DEVICE + + try: + print("🔄 Loading counting model...") + + # 初始化模型 + MODEL = CountingModule(use_box=use_box) + + # 从 Hugging Face Hub 下载权重 + ckpt_path = hf_hub_download( + repo_id="phoebe777777/111", + filename="microscopy_matching_cnt.pth", + token=None, + force_download=False + ) + + print(f"✅ Checkpoint downloaded: {ckpt_path}") + + # 加载权重 + MODEL.load_state_dict( + torch.load(ckpt_path, map_location="cpu"), + strict=True + ) + MODEL.eval() + + if torch.cuda.is_available(): + DEVICE = torch.device("cuda") + MODEL.move_to_device(DEVICE) + print("✅ Model moved to CUDA") + else: + DEVICE = torch.device("cpu") + MODEL.move_to_device(DEVICE) + print("✅ Model on CPU") + + print("✅ Counting model loaded successfully") + return MODEL, DEVICE + + except Exception as e: + print(f"❌ Error loading counting model: {e}") + import traceback + traceback.print_exc() + return None, torch.device("cpu") + + +@torch.no_grad() +def run(model, img_path, box=None, device="cpu", visualize=True): + """ + 运行计数推理 + + Args: + model: 计数模型 + img_path: 图像路径 + box: 边界框 [[x1, y1, x2, y2], ...] 
或 None + device: 设备 + visualize: 是否生成可视化 + + Returns: + result_dict: { + 'density_map': numpy array, + 'count': float, + 'visualized_path': str (如果 visualize=True) + } + """ + print("DEVICE:", device) + model.move_to_device(device) + model.eval() + if box is not None: + use_box = True + else: + use_box = False + model.use_box = use_box + + if model is None: + return { + 'density_map': None, + 'count': 0, + 'visualized_path': None, + 'error': 'Model not loaded' + } + + try: + print(f"🔄 Running counting inference on {img_path}") + + # 运行推理 (调用你的模型的 forward 方法) + with torch.no_grad(): + density_map, count = model(img_path, box) + + print(f"✅ Counting result: {count:.1f} objects") + + result = { + 'density_map': density_map, + 'count': count, + 'visualized_path': None + } + + # 可视化 + # if visualize: + # viz_path = visualize_result(img_path, density_map, count) + # result['visualized_path'] = viz_path + + return result + + except Exception as e: + print(f"❌ Counting inference error: {e}") + import traceback + traceback.print_exc() + return { + 'density_map': None, + 'count': 0, + 'visualized_path': None, + 'error': str(e) + } + + +def visualize_result(image_path, density_map, count): + """ + 可视化计数结果 (与你原来的可视化代码一致) + + Args: + image_path: 原始图像路径 + density_map: 密度图 (numpy array) + count: 计数值 + + Returns: + output_path: 可视化结果的临时文件路径 + """ + try: + import skimage.io as io + + # 读取原始图像 + img = io.imread(image_path) + + # 处理不同格式的图像 + if len(img.shape) == 3 and img.shape[2] > 3: + img = img[:, :, :3] + if len(img.shape) == 2: + img = np.stack([img]*3, axis=-1) + + # 归一化显示 + img_show = img.squeeze() + density_map_show = density_map.squeeze() + + # 归一化图像 + img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8) + + # 创建可视化 (与你原来的代码一致) + fig, ax = plt.subplots(figsize=(8, 6)) + + # 右图: 密度图叠加 + ax.imshow(img_show) + ax.imshow(density_map_show, cmap='jet', alpha=0.5) + ax.axis('off') + # ax.set_title(f"Predicted density map, count: {count:.1f}") + + 
plt.tight_layout() + + # 保存到临时文件 + temp_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png') + plt.savefig(temp_file.name, dpi=300) + plt.close() + + print(f"✅ Visualization saved to {temp_file.name}") + return temp_file.name + + except Exception as e: + print(f"❌ Visualization error: {e}") + import traceback + traceback.print_exc() + return image_path + + +# ===== 测试代码 ===== +if __name__ == "__main__": + print("="*60) + print("Testing Counting Model") + print("="*60) + + # 测试模型加载 + model, device = load_model(use_box=False) + + if model is not None: + print("\n" + "="*60) + print("Model loaded successfully, testing inference...") + print("="*60) + + # 测试推理 + test_image = "example_imgs/1977_Well_F-5_Field_1.png" + + if os.path.exists(test_image): + result = run( + model, + test_image, + box=None, + device=device, + visualize=True + ) + + if 'error' not in result: + print("\n" + "="*60) + print("Inference Results:") + print("="*60) + print(f"Count: {result['count']:.1f}") + print(f"Density map shape: {result['density_map'].shape}") + if result['visualized_path']: + print(f"Visualization saved to: {result['visualized_path']}") + else: + print(f"\n❌ Inference failed: {result['error']}") + else: + print(f"\n⚠️ Test image not found: {test_image}") + else: + print("\n❌ Model loading failed") diff --git a/inference_seg.py b/inference_seg.py new file mode 100644 index 0000000000000000000000000000000000000000..12e3bd04188565700b686c8e018961c05e6b6a20 --- /dev/null +++ b/inference_seg.py @@ -0,0 +1,87 @@ +import torch +import numpy as np +from huggingface_hub import hf_hub_download +from segmentation import SegmentationModule + +MODEL = None +DEVICE = torch.device("cpu") + +def load_model(use_box=False): + global MODEL, DEVICE + MODEL = SegmentationModule(use_box=use_box) + + ckpt_path = hf_hub_download( + repo_id="phoebe777777/111", + filename="microscopy_matching_seg.pth", + token=None, + force_download=False + ) + MODEL.load_state_dict(torch.load(ckpt_path, 
map_location="cpu"), strict=False) + MODEL.eval() + if torch.cuda.is_available(): + DEVICE = torch.device("cuda") + MODEL.move_to_device(DEVICE) + print("✅ Model moved to CUDA") + else: + DEVICE = torch.device("cpu") + MODEL.move_to_device(DEVICE) + print("✅ Model on CPU") + return MODEL, DEVICE + + +@torch.no_grad() +def run(model, img_path, box=None, device="cpu"): + print("DEVICE:", device) + model.move_to_device(device) + model.eval() + with torch.no_grad(): + if box is not None: + use_box = True + else: + use_box = False + model.use_box = use_box + output = model(img_path, box=box) + mask = output + return mask +# import os +# import torch +# import numpy as np +# from huggingface_hub import hf_hub_download +# from segmentation import SegmentationModule + +# MODEL = None +# DEVICE = torch.device("cpu") + +# def load_model(use_box=False): +# global MODEL, DEVICE + +# # === 优化1: 使用 /data 缓存模型,避免写入 .cache === +# cache_dir = "/data/cellseg_model_cache" +# os.makedirs(cache_dir, exist_ok=True) + +# ckpt_path = hf_hub_download( +# repo_id="Shengxiao0709/cellsegmodel", +# filename="microscopy_matching_seg.pth", +# token=None, +# local_dir=cache_dir, # ✅ 下载到 /data +# local_dir_use_symlinks=False, # ✅ 避免软链接问题 +# force_download=False # ✅ 已存在时不重复下载 +# ) + +# # === 优化2: 加载模型 === +# MODEL = SegmentationModule(use_box=use_box) +# state_dict = torch.load(ckpt_path, map_location="cpu") +# MODEL.load_state_dict(state_dict, strict=False) +# MODEL.eval() + +# DEVICE = torch.device("cpu") +# print(f"✅ Model loaded from {ckpt_path}") +# return MODEL, DEVICE + + +# @torch.no_grad() +# def run(model, img_path, box=None, device="cpu"): +# output = model(img_path, box=box) +# mask = output["pred"] +# mask = (mask > 0).astype(np.uint8) +# return mask \ No newline at end of file diff --git a/inference_track.py b/inference_track.py new file mode 100644 index 0000000000000000000000000000000000000000..c2df0d550ea6db5cc1f86f578632b4e8f17424b3 --- /dev/null +++ b/inference_track.py @@ -0,0 
+1,202 @@ +# inference_track.py +# 视频跟踪模型推理模块 + +import torch +import numpy as np +import os +from pathlib import Path +from tqdm import tqdm +from huggingface_hub import hf_hub_download +from tracking_one import TrackingModule +from models.tra_post_model.trackastra.tracking import graph_to_ctc + +MODEL = None +DEVICE = torch.device("cpu") + +def load_model(use_box=False): + """ + 加载跟踪模型 + + Args: + use_box: 是否使用边界框 + + Returns: + model: 加载的模型 + device: 设备 + """ + global MODEL, DEVICE + + try: + print("🔄 Loading tracking model...") + + # 初始化模型 + MODEL = TrackingModule(use_box=use_box) + + # 从 Hugging Face Hub 下载权重 + ckpt_path = hf_hub_download( + repo_id="phoebe777777/111", + filename="microscopy_matching_tra.pth", + token=None, + force_download=False + ) + + print(f"✅ Checkpoint downloaded: {ckpt_path}") + + # 加载权重 + MODEL.load_state_dict( + torch.load(ckpt_path, map_location="cpu"), + strict=True + ) + MODEL.eval() + + # 设置设备 + if torch.cuda.is_available(): + DEVICE = torch.device("cuda") + MODEL.move_to_device(DEVICE) + print("✅ Model moved to CUDA") + else: + DEVICE = torch.device("cpu") + MODEL.move_to_device(DEVICE) + print("✅ Model on CPU") + + print("✅ Tracking model loaded successfully") + return MODEL, DEVICE + + except Exception as e: + print(f"❌ Error loading tracking model: {e}") + import traceback + traceback.print_exc() + return None, torch.device("cpu") + + +@torch.no_grad() +def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"): + """ + 运行视频跟踪推理 + + Args: + model: 跟踪模型 + video_dir: 视频帧序列目录 (包含连续的图像文件) + box: 边界框 (可选) + device: 设备 + output_dir: 输出目录 + + Returns: + result_dict: { + 'track_graph': TrackGraph对象, + 'masks': 分割掩码数组 (T, H, W), + 'output_dir': 输出目录路径, + 'num_tracks': 跟踪轨迹数量 + } + """ + if model is None: + return { + 'track_graph': None, + 'masks': None, + 'output_dir': None, + 'num_tracks': 0, + 'error': 'Model not loaded' + } + + try: + print(f"🔄 Running tracking inference on {video_dir}") + + # 运行跟踪 + 
track_graph, masks = model.track( + file_dir=video_dir, + boxes=box, + mode="greedy", # 可选: "greedy", "greedy_nodiv", "ilp" + dataname="tracking_result" + ) + + # 创建输出目录 + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # 转换为CTC格式并保存 + print("🔄 Converting to CTC format...") + ctc_tracks, masks_tracked = graph_to_ctc( + track_graph, + masks, + outdir=output_dir, + ) + print(f"✅ CTC results saved to {output_dir}") + + # num_tracks = len(track_graph.tracks()) + + print(f"✅ Tracking completed") + + result = { + 'track_graph': track_graph, + 'masks': masks, + 'masks_tracked': masks_tracked, + 'output_dir': output_dir, + # 'num_tracks': num_tracks + } + + return result + + except Exception as e: + print(f"❌ Tracking inference error: {e}") + import traceback + traceback.print_exc() + return { + 'track_graph': None, + 'masks': None, + 'output_dir': None, + 'num_tracks': 0, + 'error': str(e) + } + + +def visualize_tracking_result(masks_tracked, output_path): + """ + 可视化跟踪结果 (可选) + + Args: + masks_tracked: 跟踪后的掩码 (T, H, W) + output_path: 输出视频路径 + + Returns: + output_path: 视频文件路径 + """ + try: + import cv2 + import matplotlib.pyplot as plt + from matplotlib import cm + + # 获取时间帧数 + T, H, W = masks_tracked.shape + + # 创建颜色映射 + unique_ids = np.unique(masks_tracked) + num_colors = len(unique_ids) + cmap = cm.get_cmap('tab20', num_colors) + + # 创建视频写入器 + fourcc = cv2.VideoWriter_fourcc(*'mp4v') + out = cv2.VideoWriter(output_path, fourcc, 5.0, (W, H)) + + for t in range(T): + frame = masks_tracked[t] + + # 创建彩色图像 + colored_frame = np.zeros((H, W, 3), dtype=np.uint8) + for i, obj_id in enumerate(unique_ids): + if obj_id == 0: + continue + mask = (frame == obj_id) + color = np.array(cmap(i % num_colors)[:3]) * 255 + colored_frame[mask] = color + + # 转换为BGR (OpenCV格式) + colored_frame_bgr = cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR) + out.write(colored_frame_bgr) + + out.release() + print(f"✅ Visualization saved to {output_path}") + return output_path + + except 
Exception as e: + print(f"❌ Visualization error: {e}") + return None diff --git a/models/.DS_Store b/models/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..9e34d5273b298120afcca724a30cb46f93f60ff5 Binary files /dev/null and b/models/.DS_Store differ diff --git a/models/enc_model/__init__.py b/models/enc_model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/models/enc_model/backbone.py b/models/enc_model/backbone.py new file mode 100644 index 0000000000000000000000000000000000000000..5f55124e28ace3d8f0a46260d4667edaee7d10d9 --- /dev/null +++ b/models/enc_model/backbone.py @@ -0,0 +1,64 @@ +import torch +from torch import nn +from torch.nn import functional as F +from torchvision import models +from torchvision.ops.misc import FrozenBatchNorm2d + + +class Backbone(nn.Module): + + def __init__( + self, + name: str, + pretrained: bool, + dilation: bool, + reduction: int, + swav: bool, + requires_grad: bool + ): + + super(Backbone, self).__init__() + + resnet = getattr(models, name)( + replace_stride_with_dilation=[False, False, dilation], + pretrained=pretrained, norm_layer=FrozenBatchNorm2d + ) + + self.backbone = resnet + self.reduction = reduction + + if name == 'resnet50' and swav: + checkpoint = torch.hub.load_state_dict_from_url( + 'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar', + map_location="cpu" + ) + state_dict = {k.replace("module.", ""): v for k, v in checkpoint.items()} + self.backbone.load_state_dict(state_dict, strict=False) + + # concatenation of layers 2, 3 and 4 + self.num_channels = 896 if name in ['resnet18', 'resnet34'] else 3584 + + for n, param in self.backbone.named_parameters(): + if 'layer2' not in n and 'layer3' not in n and 'layer4' not in n: + param.requires_grad_(False) + else: + param.requires_grad_(requires_grad) + + def forward(self, x): + size = x.size(-2) // self.reduction, x.size(-1) // 
self.reduction + x = self.backbone.conv1(x) + x = self.backbone.bn1(x) + x = self.backbone.relu(x) + x = self.backbone.maxpool(x) + + x = self.backbone.layer1(x) + x = layer2 = self.backbone.layer2(x) + x = layer3 = self.backbone.layer3(x) + x = layer4 = self.backbone.layer4(x) + + x = torch.cat([ + F.interpolate(f, size=size, mode='bilinear', align_corners=True) + for f in [layer2, layer3, layer4] + ], dim=1) + + return x diff --git a/models/enc_model/loca.py b/models/enc_model/loca.py new file mode 100644 index 0000000000000000000000000000000000000000..37803bfda7bbc8cab658b05f70ae378700ec841f --- /dev/null +++ b/models/enc_model/loca.py @@ -0,0 +1,232 @@ +from .backbone import Backbone +from .transformer import TransformerEncoder +from .ope import OPEModule +from .positional_encoding import PositionalEncodingsFixed +from .regression_head import DensityMapRegressor + +import torch +from torch import nn +from torch.nn import functional as F + + +class LOCA(nn.Module): + + def __init__( + self, + image_size: int, + num_encoder_layers: int, + num_ope_iterative_steps: int, + num_objects: int, + emb_dim: int, + num_heads: int, + kernel_dim: int, + backbone_name: str, + swav_backbone: bool, + train_backbone: bool, + reduction: int, + dropout: float, + layer_norm_eps: float, + mlp_factor: int, + norm_first: bool, + activation: nn.Module, + norm: bool, + zero_shot: bool, + ): + + super(LOCA, self).__init__() + + self.emb_dim = emb_dim + self.num_objects = num_objects + self.reduction = reduction + self.kernel_dim = kernel_dim + self.image_size = image_size + self.zero_shot = zero_shot + self.num_heads = num_heads + self.num_encoder_layers = num_encoder_layers + + self.backbone = Backbone( + backbone_name, pretrained=True, dilation=False, reduction=reduction, + swav=swav_backbone, requires_grad=train_backbone + ) + self.input_proj = nn.Conv2d( + self.backbone.num_channels, emb_dim, kernel_size=1 + ) + + if num_encoder_layers > 0: + self.encoder = TransformerEncoder( + 
num_encoder_layers, emb_dim, num_heads, dropout, layer_norm_eps, + mlp_factor, norm_first, activation, norm + ) + + self.ope = OPEModule( + num_ope_iterative_steps, emb_dim, kernel_dim, num_objects, num_heads, + reduction, layer_norm_eps, mlp_factor, norm_first, activation, norm, zero_shot + ) + + self.regression_head = DensityMapRegressor(emb_dim, reduction) + self.aux_heads = nn.ModuleList([ + DensityMapRegressor(emb_dim, reduction) + for _ in range(num_ope_iterative_steps - 1) + ]) + + self.pos_emb = PositionalEncodingsFixed(emb_dim) + + self.attn_norm = nn.LayerNorm(normalized_shape=(64, 64)) + self.fuse = nn.Sequential( + nn.Conv2d(324, 256, kernel_size=1, stride=1), + nn.LeakyReLU(), + nn.LayerNorm((64, 64)) + ) + + # self.fuse1 = nn.Sequential( + # nn.Conv2d(322, 256, kernel_size=1, stride=1), + # nn.LeakyReLU(), + # nn.LayerNorm((64, 64)) + # ) + + def forward_before_reg(self, x, bboxes): + num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects + # backbone + backbone_features = self.backbone(x) + # prepare the encoder input + src = self.input_proj(backbone_features) + bs, c, h, w = src.size() + pos_emb = self.pos_emb(bs, h, w, src.device).flatten(2).permute(2, 0, 1) + src = src.flatten(2).permute(2, 0, 1) + + # push through the encoder + if self.num_encoder_layers > 0: + image_features = self.encoder(src, pos_emb, src_key_padding_mask=None, src_mask=None) + else: + image_features = src + + # prepare OPE input + f_e = image_features.permute(1, 2, 0).reshape(-1, self.emb_dim, h, w) + + all_prototypes = self.ope(f_e, pos_emb, bboxes) # [3, 27, 1, 256] + + outputs = list() + response_maps_list = [] + for i in range(all_prototypes.size(0)): + prototypes = all_prototypes[i, ...].permute(1, 0, 2).reshape( + bs, num_objects, self.kernel_dim, self.kernel_dim, -1 + ).permute(0, 1, 4, 2, 3).flatten(0, 2)[:, None, ...] 
# [768, 1, 3, 3] + + response_maps = F.conv2d( + torch.cat([f_e for _ in range(num_objects)], dim=1).flatten(0, 1).unsqueeze(0), + prototypes, + bias=None, + padding=self.kernel_dim // 2, + groups=prototypes.size(0) + ).view( + bs, num_objects, self.emb_dim, h, w + ).max(dim=1)[0] + + # # send through regression heads + # if i == all_prototypes.size(0) - 1: + # predicted_dmaps = self.regression_head(response_maps) + # else: + # predicted_dmaps = self.aux_heads[i](response_maps) + # outputs.append(predicted_dmaps) + response_maps_list.append(response_maps) + + out = { + # "pred": outputs[-1], + "feature_bf_regression": response_maps_list[-1], + # "aux_pred": outputs[:-1], + "aux_feature_bf_regression": response_maps_list[:-1] + } + + return out + + def forward_reg(self, response_maps, attn_stack, unet_feature): + attn_stack = self.attn_norm(attn_stack) + attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True) + unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64] + unet_feature = unet_feature * attn_stack_mean + if unet_feature.shape[1] == 322: + unet_feature = self.fuse1(unet_feature) + else: + unet_feature = self.fuse(unet_feature) + + response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]] + + outputs = [] + for i in range(len(response_maps)): + response_map = response_maps[i] + unet_feature + if i == len(response_maps) - 1: + predicted_dmaps = self.regression_head(response_map) + else: + predicted_dmaps = self.aux_heads[i](response_map) + outputs.append(predicted_dmaps) + + return {"pred": outputs[-1], "aux_pred": outputs[:-1]} + + def forward_reg1(self, response_maps, self_attn): + # attn_stack = self.attn_norm(attn_stack) + # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True) + # unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64] + # unet_feature = unet_feature * attn_stack_mean + # if unet_feature.shape[1] == 322: + # unet_feature = 
self.fuse1(unet_feature) + # else: + # unet_feature = self.fuse(unet_feature) + + + + response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]] + + outputs = [] + for i in range(len(response_maps)): + response_map = response_maps[i] + self_attn + if i == len(response_maps) - 1: + predicted_dmaps = self.regression_head(response_map) + else: + predicted_dmaps = self.aux_heads[i](response_map) + outputs.append(predicted_dmaps) + + return {"pred": outputs[-1], "aux_pred": outputs[:-1]} + + def forward_reg_without_unet(self, response_maps, attn_stack): + # attn_stack = self.attn_norm(attn_stack) + attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True) + + response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]] + + outputs = [] + for i in range(len(response_maps)): + response_map = response_maps[i] * attn_stack_mean * 0.5 + response_maps[i] + if i == len(response_maps) - 1: + predicted_dmaps = self.regression_head(response_map) + else: + predicted_dmaps = self.aux_heads[i](response_map) + outputs.append(predicted_dmaps) + + return {"pred": outputs[-1], "aux_pred": outputs[:-1]} + + +def build_model(args): + + assert args.backbone in ['resnet18', 'resnet50', 'resnet101'] + assert args.reduction in [4, 8, 16] + + return LOCA( + image_size=args.image_size, + num_encoder_layers=args.num_enc_layers, + num_ope_iterative_steps=args.num_ope_iterative_steps, + num_objects=args.num_objects, + zero_shot=args.zero_shot, + emb_dim=args.emb_dim, + num_heads=args.num_heads, + kernel_dim=args.kernel_dim, + backbone_name=args.backbone, + swav_backbone=args.swav_backbone, + train_backbone=args.backbone_lr > 0, + reduction=args.reduction, + dropout=args.dropout, + layer_norm_eps=1e-5, + mlp_factor=8, + norm_first=args.pre_norm, + activation=nn.GELU, + norm=True, + ) diff --git a/models/enc_model/loca_args.py b/models/enc_model/loca_args.py new file mode 100644 index 
import argparse


def get_argparser():
    """Build the LOCA command-line argument parser (no --help added)."""
    parser = argparse.ArgumentParser("LOCA parser", add_help=False)

    # Experiment / paths.
    parser.add_argument('--model_name', default='loca_few_shot', type=str)
    parser.add_argument('--data_path', default='./data/FSC147_384_V2', type=str)
    parser.add_argument('--model_path', default='ckpt', type=str)

    # Backbone / architecture.
    parser.add_argument('--backbone', default='resnet50', type=str)
    parser.add_argument('--swav_backbone', action='store_true', default=True)
    parser.add_argument('--reduction', default=8, type=int)
    parser.add_argument('--image_size', default=512, type=int)
    parser.add_argument('--num_enc_layers', default=3, type=int)
    parser.add_argument('--num_ope_iterative_steps', default=3, type=int)
    parser.add_argument('--emb_dim', default=256, type=int)
    parser.add_argument('--num_heads', default=8, type=int)
    parser.add_argument('--kernel_dim', default=3, type=int)
    parser.add_argument('--num_objects', default=3, type=int)

    # Optimization / training.
    parser.add_argument('--epochs', default=200, type=int)
    parser.add_argument('--resume_training', action='store_true')
    parser.add_argument('--lr', default=1e-4, type=float)
    parser.add_argument('--backbone_lr', default=0, type=float)
    parser.add_argument('--lr_drop', default=200, type=int)
    parser.add_argument('--weight_decay', default=1e-4, type=float)
    parser.add_argument('--batch_size', default=1, type=int)
    parser.add_argument('--dropout', default=0.1, type=float)
    parser.add_argument('--num_workers', default=8, type=int)
    parser.add_argument('--max_grad_norm', default=0.1, type=float)
    parser.add_argument('--aux_weight', default=0.3, type=float)
    parser.add_argument('--tiling_p', default=0.5, type=float)
    parser.add_argument('--zero_shot', action='store_true')
    parser.add_argument('--pre_norm', action='store_true', default=True)

    return parser


# ===== models/enc_model/mlp.py =====

from torch import nn


class MLP(nn.Module):
    """Feed-forward block: Linear -> activation -> dropout -> Linear.

    Projects input_dim -> hidden_dim -> input_dim, so the output shape
    matches the input shape.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        dropout: float,
        activation: nn.Module
    ):
        super(MLP, self).__init__()

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = activation()

    def forward(self, x):
        hidden = self.activation(self.linear1(x))
        return self.linear2(self.dropout(hidden))
from .mlp import MLP
from .positional_encoding import PositionalEncodingsFixed

import torch
from torch import nn

from torchvision.ops import roi_align


class OPEModule(nn.Module):
    """Object Prototype Extraction.

    Builds per-object convolution prototypes from shape (box height/width)
    and appearance (ROI-pooled) features, optionally refined over several
    iterative adaptation transformer steps.
    """

    def __init__(
        self,
        num_iterative_steps: int,
        emb_dim: int,
        kernel_dim: int,
        num_objects: int,
        num_heads: int,
        reduction: int,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool,
    ):
        super(OPEModule, self).__init__()

        self.num_iterative_steps = num_iterative_steps
        self.zero_shot = zero_shot
        self.kernel_dim = kernel_dim
        self.num_objects = num_objects
        self.emb_dim = emb_dim
        self.reduction = reduction

        if num_iterative_steps > 0:
            self.iterative_adaptation = IterativeAdaptationModule(
                num_layers=num_iterative_steps, emb_dim=emb_dim, num_heads=num_heads,
                dropout=0, layer_norm_eps=layer_norm_eps,
                mlp_factor=mlp_factor, norm_first=norm_first,
                activation=activation, norm=norm,
                zero_shot=zero_shot
            )

        if not self.zero_shot:
            # Few-shot: map box (h, w) to a kernel_dim^2 x emb_dim shape embedding.
            self.shape_or_objectness = nn.Sequential(
                nn.Linear(2, 64),
                nn.ReLU(),
                nn.Linear(64, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, self.kernel_dim**2 * emb_dim)
            )
        else:
            # Zero-shot: learned objectness prototypes, one per object slot.
            self.shape_or_objectness = nn.Parameter(
                torch.empty((self.num_objects, self.kernel_dim**2, emb_dim))
            )
            nn.init.normal_(self.shape_or_objectness)

        self.pos_emb = PositionalEncodingsFixed(emb_dim)

    def forward(self, f_e, pos_emb, bboxes):
        """Return prototypes of shape [steps, num_obj*k*k, bs, emb_dim]."""
        bs, _, h, w = f_e.size()
        # extract the shape features or objectness
        if not self.zero_shot:
            box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
            box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
            box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]
            shape_or_objectness = self.shape_or_objectness(box_hw).reshape(
                bs, -1, self.kernel_dim ** 2, self.emb_dim
            ).flatten(1, 2).transpose(0, 1)
        else:
            shape_or_objectness = self.shape_or_objectness.expand(
                bs, -1, -1, -1
            ).flatten(1, 2).transpose(0, 1)

        # if not zero shot add appearance
        if not self.zero_shot:
            # reshape bboxes into the format suitable for roi_align:
            # rows of (batch_index, x0, y0, x1, y1).
            num_of_boxes = bboxes.size(1)
            bboxes = torch.cat([
                torch.arange(
                    bs, requires_grad=False
                ).to(bboxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
                bboxes.flatten(0, 1),
            ], dim=1)
            appearance = roi_align(
                f_e,
                boxes=bboxes, output_size=self.kernel_dim,
                spatial_scale=1.0 / self.reduction, aligned=True
            ).permute(0, 2, 3, 1).reshape(
                bs, num_of_boxes * self.kernel_dim ** 2, -1
            ).transpose(0, 1)
        else:
            num_of_boxes = self.num_objects
            appearance = None

        # Positional embedding over the kernel grid, repeated per object.
        query_pos_emb = self.pos_emb(
            bs, self.kernel_dim, self.kernel_dim, f_e.device
        ).flatten(2).permute(2, 0, 1).repeat(num_of_boxes, 1, 1)

        if self.num_iterative_steps > 0:
            memory = f_e.flatten(2).permute(2, 0, 1)
            all_prototypes = self.iterative_adaptation(
                shape_or_objectness, appearance, memory, pos_emb, query_pos_emb
            )
        else:
            # No adaptation: just combine whatever components are available.
            if shape_or_objectness is not None and appearance is not None:
                all_prototypes = (shape_or_objectness + appearance).unsqueeze(0)
            else:
                all_prototypes = (
                    shape_or_objectness if shape_or_objectness is not None else appearance
                ).unsqueeze(0)

        return all_prototypes


class IterativeAdaptationModule(nn.Module):
    """Stack of adaptation layers; returns the (normed) output of every layer."""

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool
    ):
        super(IterativeAdaptationModule, self).__init__()

        self.layers = nn.ModuleList([
            IterativeAdaptationLayer(
                emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation, zero_shot
            ) for i in range(num_layers)
        ])

        self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()

    def forward(
        self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask=None, memory_mask=None,
        tgt_key_padding_mask=None, memory_key_padding_mask=None
    ):
        output = tgt
        outputs = list()
        for layer in self.layers:
            output = layer(
                output, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
                tgt_key_padding_mask, memory_key_padding_mask
            )
            # Keep every intermediate step (one density head per step downstream).
            outputs.append(self.norm(output))

        return torch.stack(outputs)


class IterativeAdaptationLayer(nn.Module):
    """One adaptation step: (few-shot only) cross-attend prototypes to
    appearance, then attend to the image memory, then an MLP — in either
    pre-norm or post-norm arrangement."""

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        zero_shot: bool
    ):
        super(IterativeAdaptationLayer, self).__init__()

        self.norm_first = norm_first
        self.zero_shot = zero_shot

        # The appearance attention branch only exists in few-shot mode.
        if not self.zero_shot:
            self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm3 = nn.LayerNorm(emb_dim, layer_norm_eps)
        if not self.zero_shot:
            self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        if not self.zero_shot:
            self.self_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
        self.enc_dec_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)

        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        """Add a positional embedding when one is provided."""
        return x if emb is None else x + emb

    def forward(
        self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
        tgt_key_padding_mask, memory_key_padding_mask
    ):
        if self.norm_first:
            if not self.zero_shot:
                tgt_norm = self.norm1(tgt)
                tgt = tgt + self.dropout1(self.self_attn(
                    query=self.with_emb(tgt_norm, query_pos_emb),
                    key=self.with_emb(appearance, query_pos_emb),
                    value=appearance,
                    attn_mask=tgt_mask,
                    key_padding_mask=tgt_key_padding_mask
                )[0])

            tgt_norm = self.norm2(tgt)
            tgt = tgt + self.dropout2(self.enc_dec_attn(
                query=self.with_emb(tgt_norm, query_pos_emb),
                key=memory+pos_emb,
                value=memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask
            )[0])
            tgt_norm = self.norm3(tgt)
            tgt = tgt + self.dropout3(self.mlp(tgt_norm))

        else:
            if not self.zero_shot:
                tgt = self.norm1(tgt + self.dropout1(self.self_attn(
                    query=self.with_emb(tgt, query_pos_emb),
                    # Bug fix: with_emb() takes (x, emb); the original passed
                    # only `appearance`, which raised TypeError whenever this
                    # post-norm path ran. Mirrors the pre-norm branch above.
                    key=self.with_emb(appearance, query_pos_emb),
                    value=appearance,
                    attn_mask=tgt_mask,
                    key_padding_mask=tgt_key_padding_mask
                )[0]))

            tgt = self.norm2(tgt + self.dropout2(self.enc_dec_attn(
                query=self.with_emb(tgt, query_pos_emb),
                key=memory+pos_emb,
                value=memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask
            )[0]))

            tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))

        return tgt
import torch
from torch import nn


class PositionalEncodingsFixed(nn.Module):
    """Fixed (non-learned) 2D sinusoidal positional encodings.

    forward(bs, h, w, device) returns a (bs, emb_dim, h, w) tensor; the
    first emb_dim/2 channels encode the y axis, the rest the x axis.
    """

    def __init__(self, emb_dim, temperature=10000):
        super(PositionalEncodingsFixed, self).__init__()

        self.emb_dim = emb_dim
        self.temperature = temperature

    def _1d_pos_enc(self, mask, dim):
        """Sin/cos encoding of cumulative positions along `dim`."""
        temp = torch.arange(self.emb_dim // 2).float().to(mask.device)
        temp = self.temperature ** (2 * (temp.div(2, rounding_mode='floor')) / self.emb_dim)

        # (~mask).cumsum gives 1-based positions along the chosen axis.
        enc = (~mask).cumsum(dim).float().unsqueeze(-1) / temp
        enc = torch.stack([
            enc[..., 0::2].sin(), enc[..., 1::2].cos()
        ], dim=-1).flatten(-2)

        return enc

    def forward(self, bs, h, w, device):
        mask = torch.zeros(bs, h, w, dtype=torch.bool, requires_grad=False, device=device)
        x = self._1d_pos_enc(mask, dim=2)
        y = self._1d_pos_enc(mask, dim=1)

        return torch.cat([y, x], dim=3).permute(0, 3, 1, 2)


# ===== models/enc_model/regression_head.py =====

from torch import nn
import torch


class UpsamplingLayer(nn.Module):
    """Conv3x3 -> (Leaky)ReLU -> 2x bilinear upsampling."""

    def __init__(self, in_channels, out_channels, leaky=True):
        super(UpsamplingLayer, self).__init__()

        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.LeakyReLU() if leaky else nn.ReLU(),
            nn.UpsamplingBilinear2d(scale_factor=2)
        )

    def forward(self, x):
        return self.layer(x)


class DensityMapRegressor(nn.Module):
    """Upsampling head mapping (B, C, H, W) features to a (B, 1, H*r, W*r)
    density map, where r is the backbone reduction (8 or 16)."""

    def __init__(self, in_channels, reduction):
        super(DensityMapRegressor, self).__init__()

        if reduction == 8:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                nn.Conv2d(32, 1, kernel_size=1),
                nn.LeakyReLU()
            )
        elif reduction == 16:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                UpsamplingLayer(32, 16),
                nn.Conv2d(16, 1, kernel_size=1),
                nn.LeakyReLU()
            )

        self.reset_parameters()

    def forward(self, x):
        return self.regressor(x)

    def reset_parameters(self):
        """Normal-init conv weights (std=0.01), zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)


class DensityMapRegressor_(nn.Module):
    """Variant of DensityMapRegressor kept as a separate class (identical body)."""

    def __init__(self, in_channels, reduction):
        # Bug fix: the original called super(DensityMapRegressor, self)
        # with the WRONG class, which raises TypeError ("obj must be an
        # instance or subtype") the moment this class is instantiated.
        super(DensityMapRegressor_, self).__init__()

        if reduction == 8:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                nn.Conv2d(32, 1, kernel_size=1),
                nn.LeakyReLU()
            )
        elif reduction == 16:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                UpsamplingLayer(32, 16),
                nn.Conv2d(16, 1, kernel_size=1),
                nn.LeakyReLU()
            )

        self.reset_parameters()

    def forward(self, x):
        return self.regressor(x)

    def reset_parameters(self):
        """Normal-init conv weights (std=0.01), zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
from .mlp import MLP

from torch import nn


class TransformerEncoder(nn.Module):
    """A stack of TransformerEncoderLayer with an optional final LayerNorm."""

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
    ):
        super(TransformerEncoder, self).__init__()

        self.layers = nn.ModuleList([
            TransformerEncoderLayer(
                emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation
            ) for _ in range(num_layers)
        ])

        self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()

    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        out = src
        for enc_layer in self.layers:
            out = enc_layer(out, pos_emb, src_mask, src_key_padding_mask)
        return self.norm(out)


class TransformerEncoderLayer(nn.Module):
    """Single encoder layer: self-attention + MLP, pre- or post-norm.

    The positional embedding is added to queries and keys only, not to the
    attention values.
    """

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
    ):
        super(TransformerEncoderLayer, self).__init__()

        self.norm_first = norm_first

        self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.self_attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout
        )
        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        """Add a positional embedding when one is provided."""
        return x if emb is None else x + emb

    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        if self.norm_first:
            # Pre-norm: normalize, attend, then add the residual.
            normed = self.norm1(src)
            qk = normed + pos_emb
            attn_out = self.self_attn(
                query=qk,
                key=qk,
                value=normed,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask
            )[0]
            src = src + self.dropout1(attn_out)

            normed = self.norm2(src)
            src = src + self.dropout2(self.mlp(normed))
        else:
            # Post-norm: attend, add the residual, then normalize.
            qk = src + pos_emb
            attn_out = self.self_attn(
                query=qk,
                key=qk,
                value=src,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask
            )[0]
            src = self.norm1(src + self.dropout1(attn_out))

            src = self.norm2(src + self.dropout2(self.mlp(src)))

        return src


# ===== models/enc_model/unet_parts.py =====

""" Parts of the U-Net model """

import torch
import torch.nn as nn
import torch.nn.functional as F
"""(convolution => [BN] => ReLU) * 2""" + + def __init__(self, in_channels, out_channels, mid_channels=None): + super().__init__() + if not mid_channels: + mid_channels = out_channels + self.double_conv = nn.Sequential( + nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(mid_channels), + nn.ReLU(inplace=True), + nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False), + nn.BatchNorm2d(out_channels), + nn.ReLU(inplace=True) + ) + + def forward(self, x): + return self.double_conv(x) + + +class Down(nn.Module): + """Downscaling with maxpool then double conv""" + + def __init__(self, in_channels, out_channels): + super().__init__() + self.maxpool_conv = nn.Sequential( + nn.MaxPool2d(2), + DoubleConv(in_channels, out_channels) + ) + + def forward(self, x): + return self.maxpool_conv(x) + + +class Up(nn.Module): + """Upscaling then double conv""" + + def __init__(self, in_channels, out_channels, bilinear=True): + super().__init__() + + # if bilinear, use the normal convolutions to reduce the number of channels + if bilinear: + self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True) + self.conv = DoubleConv(in_channels, out_channels, in_channels // 2) + else: + self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2) + self.conv = DoubleConv(in_channels, out_channels) + + def forward(self, x1, x2): + x1 = self.up(x1) + # input is CHW + diffY = x2.size()[2] - x1.size()[2] + diffX = x2.size()[3] - x1.size()[3] + + x1 = F.pad(x1, [diffX // 2, diffX - diffX // 2, + diffY // 2, diffY - diffY // 2]) + # if you have padding issues, see + # https://github.com/HaiyongJiang/U-Net-Pytorch-Unstructured-Buggy/commit/0e854509c2cea854e247a9c615f175f76fbb2e3a + # https://github.com/xiaopeng-liao/Pytorch-UNet/commit/8ebac70e633bac59fc22bb5195e513d5832fb3bd + x = torch.cat([x2, x1], dim=1) + return self.conv(x) + + +class OutConv(nn.Module): + def __init__(self, in_channels, 
out_channels): + super(OutConv, self).__init__() + self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1) + + def forward(self, x): + return self.conv(x) diff --git a/models/model.py b/models/model.py new file mode 100644 index 0000000000000000000000000000000000000000..de8006f569f04511c76cfe677ae9de7885c18c1d --- /dev/null +++ b/models/model.py @@ -0,0 +1,653 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import os +import clip +import sys +import numpy as np +from models.seg_post_model.cellpose.models import CellposeModel + +from torchvision.ops import roi_align +def crop_roi_feat(feat, boxes): + """ + feat: 1 x c x h x w + boxes: m x 4, 4: [y_tl, x_tl, y_br, x_br] + """ + _, _, h, w = feat.shape + out_stride = 512 / h + boxes_scaled = boxes / out_stride + boxes_scaled[:, :2] = torch.floor(boxes_scaled[:, :2]) # y_tl, x_tl: floor + boxes_scaled[:, 2:] = torch.ceil(boxes_scaled[:, 2:]) # y_br, x_br: ceil + boxes_scaled[:, :2] = torch.clamp_min(boxes_scaled[:, :2], 0) + boxes_scaled[:, 2] = torch.clamp_max(boxes_scaled[:, 2], h) + boxes_scaled[:, 3] = torch.clamp_max(boxes_scaled[:, 3], w) + feat_boxes = [] + for idx_box in range(0, boxes.shape[0]): + y_tl, x_tl, y_br, x_br = boxes_scaled[idx_box] + y_tl, x_tl, y_br, x_br = int(y_tl), int(x_tl), int(y_br), int(x_br) + feat_box = feat[:, :, y_tl : (y_br + 1), x_tl : (x_br + 1)] + feat_boxes.append(feat_box) + return feat_boxes + +class Counting_with_SD_features(nn.Module): + def __init__(self, scale_factor): + super(Counting_with_SD_features, self).__init__() + self.adapter = adapter_roi() + # self.regressor = regressor_with_SD_features() + +class Counting_with_SD_features_loca(nn.Module): + def __init__(self, scale_factor): + super(Counting_with_SD_features_loca, self).__init__() + self.adapter = adapter_roi_loca() + self.regressor = regressor_with_SD_features() + + +class Counting_with_SD_features_dino_vit_c3(nn.Module): + def __init__(self, scale_factor, vit=None): + 
super(Counting_with_SD_features_dino_vit_c3, self).__init__() + self.adapter = adapter_roi_loca() + self.regressor = regressor_with_SD_features_seg_vit_c3() + +class Counting_with_SD_features_track(nn.Module): + def __init__(self, scale_factor, vit=None): + super(Counting_with_SD_features_track, self).__init__() + self.adapter = adapter_roi_loca() + self.regressor = regressor_with_SD_features_tra() + + +class adapter_roi(nn.Module): + def __init__(self, pool_size=[3, 3]): + super(adapter_roi, self).__init__() + self.pool_size = pool_size + self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + # self.relu = nn.ReLU() + # self.conv2 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.pool = nn.MaxPool2d(2) + self.fc = nn.Linear(256 * 3 * 3, 768) + # **new + self.fc1 = nn.Sequential( + nn.ReLU(), + nn.Linear(768, 768 // 4, bias=False), + nn.ReLU() + ) + self.fc2 = nn.Sequential( + nn.Linear(768 // 4, 768, bias=False), + # nn.ReLU() + ) + self.initialize_weights() + + def forward(self, x, boxes): + num_of_boxes = boxes.shape[1] + rois = [] + bs, _, h, w = x.shape + boxes = torch.cat([ + torch.arange( + bs, requires_grad=False + ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1), + boxes.flatten(0, 1), + ], dim=1) + rois = roi_align( + x, + boxes=boxes, output_size=3, + spatial_scale=1.0 / 8, aligned=True + ) + rois = torch.mean(rois, dim=0, keepdim=True) + x = self.conv1(rois) + x = x.view(x.size(0), -1) + x = self.fc(x) + + x = self.fc1(x) + x = self.fc2(x) + return x + + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class adapter_roi_loca(nn.Module): + def __init__(self, pool_size=[3, 3]): + super(adapter_roi_loca, self).__init__() + self.pool_size = pool_size + self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1) + self.pool = nn.MaxPool2d(2) + self.fc = nn.Linear(256 * 
3 * 3, 768) + self.initialize_weights() + def forward(self, x, boxes): + num_of_boxes = boxes.shape[1] + rois = [] + bs, _, h, w = x.shape + if h != 512 or w != 512: + x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False) + if bs == 1: + boxes = torch.cat([ + torch.arange( + bs, requires_grad=False + ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1), + boxes.flatten(0, 1), + ], dim=1) + rois = roi_align( + x, + boxes=boxes, output_size=3, + spatial_scale=1.0 / 8, aligned=True + ) + rois = torch.mean(rois, dim=0, keepdim=True) + else: + boxes = torch.cat([ + boxes.flatten(0, 1), + ], dim=1).split(num_of_boxes, dim=0) + rois = roi_align( + x, + boxes=boxes, output_size=3, + spatial_scale=1.0 / 8, aligned=True + ) + rois = rois.split(num_of_boxes, dim=0) + rois = torch.stack(rois, dim=0) + rois = torch.mean(rois, dim=1, keepdim=False) + x = self.conv1(rois) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + def forward_boxes(self, x, boxes): + num_of_boxes = boxes.shape[1] + rois = [] + bs, _, h, w = x.shape + if h != 512 or w != 512: + x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False) + if bs == 1: + boxes = torch.cat([ + torch.arange( + bs, requires_grad=False + ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1), + boxes.flatten(0, 1), + ], dim=1) + rois = roi_align( + x, + boxes=boxes, output_size=3, + spatial_scale=1.0 / 8, aligned=True + ) + # rois = torch.mean(rois, dim=0, keepdim=True) + else: + raise NotImplementedError + x = self.conv1(rois) + x = x.view(x.size(0), -1) + x = self.fc(x) + return x + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + + + +class regressor1(nn.Module): + def __init__(self): + super(regressor1, self).__init__() + self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1) + 
self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1) + self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1) + self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2) + self.leaky_relu = nn.LeakyReLU() + self.relu = nn.ReLU() + self.initialize_weights() + + + + def forward(self, x): + x_ = self.conv1(x) + x_ = self.leaky_relu(x_) + x_ = self.upsampler(x_) + x_ = self.conv2(x_) + x_ = self.leaky_relu(x_) + x_ = self.upsampler(x_) + x_ = self.conv3(x_) + x_ = self.relu(x_) + out = x_ + return out + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class regressor1(nn.Module): + def __init__(self): + super(regressor1, self).__init__() + self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1) + self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1) + self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2) + self.leaky_relu = nn.LeakyReLU() + self.relu = nn.ReLU() + + def forward(self, x): + x_ = self.conv1(x) + x_ = self.leaky_relu(x_) + x_ = self.upsampler(x_) + x_ = self.conv2(x_) + x_ = self.leaky_relu(x_) + x_ = self.upsampler(x_) + x_ = self.conv3(x_) + x_ = self.relu(x_) + out = x_ + return out + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class regressor_with_SD_features(nn.Module): + def __init__(self): + super(regressor_with_SD_features, self).__init__() + self.layer1 = nn.Sequential( + nn.Conv2d(324, 256, kernel_size=1, stride=1), + nn.LeakyReLU(), + nn.LayerNorm((64, 64)) + ) + self.layer2 = nn.Sequential( + nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.LeakyReLU(), + nn.ConvTranspose2d(in_channels=128, 
out_channels=128, kernel_size=4, stride=2, padding=1), + ) + self.layer3 = nn.Sequential( + nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1), + ) + self.layer4 = nn.Sequential( + nn.Conv2d(64, 32, kernel_size=3, padding=1), + nn.LeakyReLU(), + nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1), + ) + self.conv = nn.Sequential( + nn.Conv2d(32, 1, kernel_size=1), + nn.ReLU() + ) + self.norm = nn.LayerNorm(normalized_shape=(64, 64)) + self.initialize_weights() + + def forward(self, attn_stack, feature_list): + attn_stack = self.norm(attn_stack) + unet_feature = feature_list[-1] + attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True) + unet_feature = unet_feature * attn_stack_mean + unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64] + x = self.layer1(unet_feature) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + out = self.conv(x) + return out / 100 + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + +class regressor_with_SD_features_seg(nn.Module): + def __init__(self): + super(regressor_with_SD_features_seg, self).__init__() + self.layer1 = nn.Sequential( + nn.Conv2d(324, 256, kernel_size=1, stride=1), + nn.LeakyReLU(), + nn.LayerNorm((64, 64)) + ) + self.layer2 = nn.Sequential( + nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.LeakyReLU(), + nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1), + ) + self.layer3 = nn.Sequential( + nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1), + ) + self.layer4 = nn.Sequential( + nn.Conv2d(64, 32, kernel_size=3, padding=1), + nn.LeakyReLU(), + 
nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1), + ) + self.conv = nn.Sequential( + nn.Conv2d(32, 2, kernel_size=1), + # nn.ReLU() + ) + self.norm = nn.LayerNorm(normalized_shape=(64, 64)) + self.initialize_weights() + + def forward(self, attn_stack, feature_list): + attn_stack = self.norm(attn_stack) + unet_feature = feature_list[-1] + attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True) + unet_feature = unet_feature * attn_stack_mean + unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64] + x = self.layer1(unet_feature) + x = self.layer2(x) + x = self.layer3(x) + x = self.layer4(x) + out = self.conv(x) + return out + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +from models.enc_model.unet_parts import * + + +class regressor_with_SD_features_seg_vit_c3(nn.Module): + def __init__(self, n_channels=3, n_classes=2, bilinear=False): + super(regressor_with_SD_features_seg_vit_c3, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + self.norm = nn.LayerNorm(normalized_shape=(64, 64)) + self.inc_0 = nn.Conv2d(n_channels, 3, kernel_size=3, padding=1) + self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False) + self.vit = self.vit_model.net + + def forward(self, img, attn_stack, feature_list): + attn_stack = attn_stack[:, [1,3], ...] 
+ attn_stack = self.norm(attn_stack) + unet_feature = feature_list[-1] + unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True) + + x = torch.cat([unet_feature_mean, attn_stack], dim=1) # [1, 324, 64, 64] + + if x.shape[-1] != 512: + x = F.interpolate(x, size=(512, 512), mode="bilinear") + x = self.inc_0(x) + + + + out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0] + if out.dtype == np.uint16: + out = out.astype(np.int16) + out = torch.from_numpy(out).unsqueeze(0).to(x.device) + return out + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + +class regressor_with_SD_features_tra(nn.Module): + def __init__(self, n_channels=2, n_classes=2, bilinear=False): + super(regressor_with_SD_features_tra, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + self.norm = nn.LayerNorm(normalized_shape=(64, 64)) + + # segmentation + self.inc_0 = nn.Conv2d(3, 3, kernel_size=3, padding=1) + self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False) + self.vit = self.vit_model.net + + self.inc_1 = nn.Conv2d(n_channels, 1, kernel_size=3, padding=1) + self.mlp = nn.Linear(64 * 64, 320) + # self.vit = self.vit_model.net.float() + + def forward_seg(self, img, attn_stack, feature_list, mask, training=False): + attn_stack = attn_stack[:, [1,3], ...] 
+ attn_stack = self.norm(attn_stack) + unet_feature = feature_list[-1] + unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True) + x = torch.cat([unet_feature_mean, attn_stack], dim=1) # [1, 324, 64, 64] + + if x.shape[-1] != 512: + x = F.interpolate(x, size=(512, 512), mode="bilinear") + x = self.inc_0(x) + feat = x + + out = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=x.squeeze().cpu().numpy())[0] + if out.dtype == np.uint16: + out = out.astype(np.int16) + out = torch.from_numpy(out).unsqueeze(0).to(x.device) + return out, 0., feat + + def forward(self, attn_prev, feature_list_prev, attn_after, feature_list_after): + assert attn_prev.shape == attn_after.shape, "attn_prev and attn_after must have the same shape" + n_instances = attn_prev.shape[0] + attn_prev = self.norm(attn_prev) # [n_instances, 1, 64, 64] + attn_after = self.norm(attn_after) + + x = torch.cat([attn_prev, attn_after], dim=1) # n_instances, 2, 64, 64 + + x = self.inc_1(x) + x = x.view(1, n_instances, -1) # Flatten the tensor to [n_instances, 64*64*4] + x = self.mlp(x) # Apply the MLP to get the output + + return x # Output shape will be [n_instances, 4] + + + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + + +class regressor_with_SD_features_inst_seg_unet(nn.Module): + def __init__(self, n_channels=8, n_classes=3, bilinear=False): + super(regressor_with_SD_features_inst_seg_unet, self).__init__() + self.n_channels = n_channels + self.n_classes = n_classes + self.bilinear = bilinear + self.norm = nn.LayerNorm(normalized_shape=(64, 64)) + self.inc_0 = (DoubleConv(n_channels, 3)) + self.inc = (DoubleConv(3, 64)) + self.down1 = (Down(64, 128)) + self.down2 = (Down(128, 256)) + self.down3 = (Down(256, 512)) + factor = 2 if bilinear else 1 + self.down4 = (Down(512, 1024 // factor)) + self.up1 = (Up(1024, 512 // factor, 
bilinear)) + self.up2 = (Up(512, 256 // factor, bilinear)) + self.up3 = (Up(256, 128 // factor, bilinear)) + self.up4 = (Up(128, 64, bilinear)) + self.outc = (OutConv(64, n_classes)) + + def forward(self, img, attn_stack, feature_list): + attn_stack = self.norm(attn_stack) + unet_feature = feature_list[-1] + unet_feature_mean = torch.mean(unet_feature, dim=1, keepdim=True) + attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True) + unet_feature_mean = unet_feature_mean * attn_stack_mean + x = torch.cat([unet_feature_mean, attn_stack], dim=1) # [1, 324, 64, 64] + if x.shape[-1] != 512: + x = F.interpolate(x, size=(512, 512), mode="bilinear") + x = torch.cat([img, x], dim=1) # [1, 8, 512, 512] + x = self.inc_0(x) + x1 = self.inc(x) + x2 = self.down1(x1) + x3 = self.down2(x2) + x4 = self.down3(x3) + x5 = self.down4(x4) + x = self.up1(x5, x4) + x = self.up2(x, x3) + x = self.up3(x, x2) + x = self.up4(x, x1) + out = self.outc(x) + return out + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class regressor_with_SD_features_self(nn.Module): + def __init__(self): + super(regressor_with_SD_features_self, self).__init__() + self.layer = nn.Sequential( + nn.Conv2d(4096, 1024, kernel_size=1, stride=1), + nn.LeakyReLU(), + nn.LayerNorm((64, 64)), + nn.Conv2d(1024, 256, kernel_size=1, stride=1), + nn.LeakyReLU(), + nn.LayerNorm((64, 64)), + ) + self.layer2 = nn.Sequential( + nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.LeakyReLU(), + nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1), + ) + self.layer3 = nn.Sequential( + nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1), + ) + self.layer4 = nn.Sequential( + nn.Conv2d(64, 32, kernel_size=3, padding=1), + 
nn.LeakyReLU(), + nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1), + ) + self.conv = nn.Sequential( + nn.Conv2d(32, 1, kernel_size=1), + nn.ReLU() + ) + self.norm = nn.LayerNorm(normalized_shape=(64, 64)) + self.initialize_weights() + + def forward(self, self_attn): + self_attn = self_attn.permute(2, 0, 1) + self_attn = self.layer(self_attn) + return self_attn + # attn_stack = self.norm(attn_stack) + # unet_feature = feature_list[-1] + # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True) + # unet_feature = unet_feature * attn_stack_mean + # unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64] + # x = self.layer(unet_feature) + # x = self.layer2(x) + # x = self.layer3(x) + # x = self.layer4(x) + # out = self.conv(x) + # return out / 100 + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +class regressor_with_SD_features_latent(nn.Module): + def __init__(self): + super(regressor_with_SD_features_latent, self).__init__() + self.layer = nn.Sequential( + nn.Conv2d(4, 256, kernel_size=1, stride=1), + nn.LeakyReLU(), + nn.LayerNorm((64, 64)) + ) + self.layer2 = nn.Sequential( + nn.Conv2d(256, 128, kernel_size=3, padding=1), + nn.LeakyReLU(), + nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1), + ) + self.layer3 = nn.Sequential( + nn.Conv2d(128, 64, kernel_size=3, padding=1), + nn.ReLU(), + nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1), + ) + self.layer4 = nn.Sequential( + nn.Conv2d(64, 32, kernel_size=3, padding=1), + nn.LeakyReLU(), + nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1), + ) + self.conv = nn.Sequential( + nn.Conv2d(32, 1, kernel_size=1), + nn.ReLU() + ) + self.norm = nn.LayerNorm(normalized_shape=(64, 
64)) + self.initialize_weights() + + def forward(self, self_attn): + # self_attn = self_attn.permute(2, 0, 1) + self_attn = self.layer(self_attn) + return self_attn + # attn_stack = self.norm(attn_stack) + # unet_feature = feature_list[-1] + # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True) + # unet_feature = unet_feature * attn_stack_mean + # unet_feature = torch.cat([unet_feature, attn_stack], dim=1) # [1, 324, 64, 64] + # x = self.layer(unet_feature) + # x = self.layer2(x) + # x = self.layer3(x) + # x = self.layer4(x) + # out = self.conv(x) + # return out / 100 + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + + + + +class regressor_with_deconv(nn.Module): + def __init__(self): + super(regressor_with_deconv, self).__init__() + self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1) + self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1) + self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1) + self.deconv1 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1) + self.deconv2 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1) + self.leaky_relu = nn.LeakyReLU() + self.relu = nn.ReLU() + self.initialize_weights() + + def forward(self, x): + x_ = self.conv1(x) + x_ = self.leaky_relu(x_) + x_ = self.deconv1(x_) + x_ = self.conv2(x_) + x_ = self.leaky_relu(x_) + x_ = self.deconv2(x_) + x_ = self.conv3(x_) + x_ = self.relu(x_) + out = x_ + return out + + def initialize_weights(self): + for m in self.modules(): + if isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d) or isinstance(m, nn.Linear): + nn.init.xavier_normal_(m.weight) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + + diff --git a/models/seg_post_model/cellpose/__init__.py b/models/seg_post_model/cellpose/__init__.py new file mode 100644 index 
0000000000000000000000000000000000000000..0bd47697fc85ad77549ec737cb6c17ea48363696 --- /dev/null +++ b/models/seg_post_model/cellpose/__init__.py @@ -0,0 +1 @@ +from .version import version, version_str diff --git a/models/seg_post_model/cellpose/__main__.py b/models/seg_post_model/cellpose/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..f50d2497cbe9b349028aee8646dca6a2646ebccb --- /dev/null +++ b/models/seg_post_model/cellpose/__main__.py @@ -0,0 +1,272 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu. +""" +import os, time +import numpy as np +from tqdm import tqdm +from cellpose import utils, models, io, train +from .version import version_str +from cellpose.cli import get_arg_parser + +try: + from cellpose.gui import gui3d, gui + GUI_ENABLED = True +except ImportError as err: + GUI_ERROR = err + GUI_ENABLED = False + GUI_IMPORT = True +except Exception as err: + GUI_ENABLED = False + GUI_ERROR = err + GUI_IMPORT = False + raise + +import logging + + +def main(): + """ Run cellpose from command line + """ + + args = get_arg_parser().parse_args() # this has to be in a separate file for autodoc to work + + if args.version: + print(version_str) + return + + ######## if no image arguments are provided, run GUI or add model and exit ######## + if len(args.dir) == 0 and len(args.image_path) == 0: + if args.add_model: + io.add_model(args.add_model) + return + else: + if not GUI_ENABLED: + print("GUI ERROR: %s" % GUI_ERROR) + if GUI_IMPORT: + print( + "GUI FAILED: GUI dependencies may not be installed, to install, run" + ) + print(" pip install 'cellpose[gui]'") + else: + if args.Zstack: + gui3d.run() + else: + gui.run() + return + + ############################## run cellpose on images ############################## + if args.verbose: + from .io import logger_setup + logger, log_file = logger_setup() + else: + print( + ">>>> !LOGGING OFF BY DEFAULT! 
To see cellpose progress, set --verbose") + print("No --verbose => no progress or info printed") + logger = logging.getLogger(__name__) + + + # find images + if len(args.img_filter) > 0: + image_filter = args.img_filter + else: + image_filter = None + + device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu, + device=args.gpu_device) + + if args.pretrained_model is None or args.pretrained_model == "None" or args.pretrained_model == "False" or args.pretrained_model == "0": + pretrained_model = "cpsam" + logger.warning("training from scratch is disabled, using 'cpsam' model") + else: + pretrained_model = args.pretrained_model + + # Warn users about old arguments from CP3: + if args.pretrained_model_ortho: + logger.warning( + "the '--pretrained_model_ortho' flag is deprecated in v4.0.1+ and no longer used") + if args.train_size: + logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used") + if args.chan or args.chan2: + logger.warning('--chan and --chan2 are deprecated, all channels are used by default') + if args.all_channels: + logger.warning("the '--all_channels' flag is deprecated in v4.0.1+ and no longer used") + if args.restore_type: + logger.warning("the '--restore_type' flag is deprecated in v4.0.1+ and no longer used") + if args.transformer: + logger.warning("the '--tranformer' flag is deprecated in v4.0.1+ and no longer used") + if args.invert: + logger.warning("the '--invert' flag is deprecated in v4.0.1+ and no longer used") + if args.chan2_restore: + logger.warning("the '--chan2_restore' flag is deprecated in v4.0.1+ and no longer used") + if args.diam_mean: + logger.warning("the '--diam_mean' flag is deprecated in v4.0.1+ and no longer used") + if args.train_size: + logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used") + + if args.norm_percentile is not None: + value1, value2 = args.norm_percentile + normalize = {'percentile': (float(value1), float(value2))} + else: + normalize = 
(not args.no_norm) + + if args.save_each: + if not args.save_every: + raise ValueError("ERROR: --save_each requires --save_every") + + if len(args.image_path) > 0 and args.train: + raise ValueError("ERROR: cannot train model with single image input") + + ## Run evaluation on images + if not args.train: + _evaluate_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize) + + ## Train a model ## + else: + _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize) + + +def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize): + test_dir = None if len(args.test_dir) == 0 else args.test_dir + images, labels, image_names, train_probs = None, None, None, None + test_images, test_labels, image_names_test, test_probs = None, None, None, None + compute_flows = False + if len(args.file_list) > 0: + if os.path.exists(args.file_list): + dat = np.load(args.file_list, allow_pickle=True).item() + image_names = dat["train_files"] + image_names_test = dat.get("test_files", None) + train_probs = dat.get("train_probs", None) + test_probs = dat.get("test_probs", None) + compute_flows = dat.get("compute_flows", False) + load_files = False + else: + logger.critical(f"ERROR: {args.file_list} does not exist") + else: + output = io.load_train_test_data(args.dir, test_dir, image_filter, + args.mask_filter, + args.look_one_level_down) + images, labels, image_names, test_images, test_labels, image_names_test = output + load_files = True + + # initialize model + model = models.CellposeModel(device=device, pretrained_model=pretrained_model) + + # train segmentation model + cpmodel_path = train.train_seg( + model.net, images, labels, train_files=image_names, + test_data=test_images, test_labels=test_labels, + test_files=image_names_test, train_probs=train_probs, + test_probs=test_probs, compute_flows=compute_flows, + load_files=load_files, normalize=normalize, + channel_axis=args.channel_axis, + 
learning_rate=args.learning_rate, weight_decay=args.weight_decay, + SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.train_batch_size, + min_train_masks=args.min_train_masks, + nimg_per_epoch=args.nimg_per_epoch, + nimg_test_per_epoch=args.nimg_test_per_epoch, + save_path=os.path.realpath(args.dir), + save_every=args.save_every, + save_each=args.save_each, + model_name=args.model_name_out)[0] + model.pretrained_model = cpmodel_path + logger.info(">>>> model trained and saved to %s" % cpmodel_path) + return model + + +def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, normalize): + # Check with user if they REALLY mean to run without saving anything + if not args.train: + saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt + + tic = time.time() + if len(args.dir) > 0: + image_names = io.get_image_files( + args.dir, args.mask_filter, imf=imf, + look_one_level_down=args.look_one_level_down) + else: + if os.path.exists(args.image_path): + image_names = [args.image_path] + else: + raise ValueError(f"ERROR: no file found at {args.image_path}") + nimg = len(image_names) + + if args.savedir: + if not os.path.exists(args.savedir): + raise FileExistsError(f"--savedir {args.savedir} does not exist") + + logger.info( + ">>>> running cellpose on %d images using all channels" % nimg) + + # handle built-in model exceptions + model = models.CellposeModel(device=device, pretrained_model=pretrained_model,) + + tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO) + + channel_axis = args.channel_axis + z_axis = args.z_axis + + for image_name in tqdm(image_names, file=tqdm_out): + if args.do_3D or args.stitch_threshold > 0.: + logger.info('loading image as 3D zstack') + image = io.imread_3D(image_name) + if channel_axis is None: + channel_axis = 3 + if z_axis is None: + z_axis = 0 + + else: + image = io.imread_2D(image_name) + out = model.eval( + image, + diameter=args.diameter, + do_3D=args.do_3D, + 
augment=args.augment, + flow_threshold=args.flow_threshold, + cellprob_threshold=args.cellprob_threshold, + stitch_threshold=args.stitch_threshold, + min_size=args.min_size, + batch_size=args.batch_size, + bsize=args.bsize, + resample=not args.no_resample, + normalize=normalize, + channel_axis=channel_axis, + z_axis=z_axis, + anisotropy=args.anisotropy, + niter=args.niter, + flow3D_smooth=args.flow3D_smooth) + masks, flows = out[:2] + + if args.exclude_on_edges: + masks = utils.remove_edge_masks(masks) + if not args.no_npy: + io.masks_flows_to_seg(image, masks, flows, image_name, + imgs_restore=None, + restore_type=None, + ratio=1.) + if saving_something: + suffix = "_cp_masks" + if args.output_name is not None: + # (1) If `savedir` is not defined, then must have a non-zero `suffix` + if args.savedir is None and len(args.output_name) > 0: + suffix = args.output_name + elif args.savedir is not None and not os.path.samefile(args.savedir, args.dir): + # (2) If `savedir` is defined, and different from `dir` then + # takes the value passed as a param. 
(which can be empty string) + suffix = args.output_name + + io.save_masks(image, masks, flows, image_name, + suffix=suffix, png=args.save_png, + tif=args.save_tif, save_flows=args.save_flows, + save_outlines=args.save_outlines, + dir_above=args.dir_above, savedir=args.savedir, + save_txt=args.save_txt, in_folders=args.in_folders, + save_mpl=args.save_mpl) + if args.save_rois: + io.save_rois(masks, image_name) + logger.info(">>>> completed in %0.3f sec" % (time.time() - tic)) + + return model + + +if __name__ == "__main__": + main() diff --git a/models/seg_post_model/cellpose/cli.py b/models/seg_post_model/cellpose/cli.py new file mode 100644 index 0000000000000000000000000000000000000000..9056007170785c1a2069963b28c129a63cac1b10 --- /dev/null +++ b/models/seg_post_model/cellpose/cli.py @@ -0,0 +1,240 @@ +""" +Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden. +""" + +import argparse + + +def get_arg_parser(): + """ Parses command line arguments for cellpose main function + + Note: this function has to be in a separate file to allow autodoc to work for CLI. 
+ The autodoc_mock_imports in conf.py does not work for sphinx-argparse sometimes, + see https://github.com/ashb/sphinx-argparse/issues/9#issue-1097057823 + """ + + parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters") + + # misc settings + parser.add_argument("--version", action="store_true", + help="show cellpose version info") + parser.add_argument( + "--verbose", action="store_true", + help="show information about running and settings and save to log") + parser.add_argument("--Zstack", action="store_true", help="run GUI in 3D mode") + + # settings for CPU vs GPU + hardware_args = parser.add_argument_group("Hardware Arguments") + hardware_args.add_argument("--use_gpu", action="store_true", + help="use gpu if torch with cuda installed") + hardware_args.add_argument( + "--gpu_device", required=False, default="0", type=str, + help="which gpu device to use, use an integer for torch, or mps for M1") + + # settings for locating and formatting images + input_img_args = parser.add_argument_group("Input Image Arguments") + input_img_args.add_argument("--dir", default=[], type=str, + help="folder containing data to run or train on.") + input_img_args.add_argument( + "--image_path", default=[], type=str, help= + "if given and --dir not given, run on single image instead of folder (cannot train with this option)" + ) + input_img_args.add_argument( + "--look_one_level_down", action="store_true", + help="run processing on all subdirectories of current folder") + input_img_args.add_argument("--img_filter", default=[], type=str, + help="end string for images to run on") + input_img_args.add_argument( + "--channel_axis", default=None, type=int, + help="axis of image which corresponds to image channels") + input_img_args.add_argument("--z_axis", default=None, type=int, + help="axis of image which corresponds to Z dimension") + + # TODO: remove deprecated in future version + input_img_args.add_argument( + "--chan", default=0, type=int, help= + 
"Deprecated in v4.0.1+, not used. ") + input_img_args.add_argument( + "--chan2", default=0, type=int, help= + 'Deprecated in v4.0.1+, not used. ') + input_img_args.add_argument("--invert", action="store_true", help= + 'Deprecated in v4.0.1+, not used. ') + input_img_args.add_argument( + "--all_channels", action="store_true", help= + 'Deprecated in v4.0.1+, not used. ') + + # model settings + model_args = parser.add_argument_group("Model Arguments") + model_args.add_argument("--pretrained_model", required=False, default="cpsam", + type=str, + help="model to use for running or starting training") + model_args.add_argument( + "--add_model", required=False, default=None, type=str, + help="model path to copy model to hidden .cellpose folder for using in GUI/CLI") + model_args.add_argument("--pretrained_model_ortho", required=False, default=None, + type=str, + help="Deprecated in v4.0.1+, not used. ") + + # TODO: remove deprecated in future version + model_args.add_argument("--restore_type", required=False, default=None, type=str, help= + 'Deprecated in v4.0.1+, not used. ') + model_args.add_argument("--chan2_restore", action="store_true", help= + 'Deprecated in v4.0.1+, not used. 
') + model_args.add_argument( + "--transformer", action="store_true", help= + "use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)") + + # algorithm settings + algorithm_args = parser.add_argument_group("Algorithm Arguments") + algorithm_args.add_argument("--no_norm", action="store_true", + help="do not normalize images (normalize=False)") + algorithm_args.add_argument( + '--norm_percentile', + nargs=2, # Require exactly two values + metavar=('VALUE1', 'VALUE2'), + help="Provide two float values to set norm_percentile (e.g., --norm_percentile 1 99)" + ) + algorithm_args.add_argument( + "--do_3D", action="store_true", + help="process images as 3D stacks of images (nplanes x nchan x Ly x Lx") + algorithm_args.add_argument( + "--diameter", required=False, default=None, type=float, help= + "use to resize cells to the training diameter (30 pixels)" + ) + algorithm_args.add_argument( + "--stitch_threshold", required=False, default=0.0, type=float, + help="compute masks in 2D then stitch together masks with IoU>0.9 across planes" + ) + algorithm_args.add_argument( + "--min_size", required=False, default=15, type=int, + help="minimum number of pixels per mask, can turn off with -1") + algorithm_args.add_argument( + "--flow3D_smooth", required=False, default=0, type=float, + help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing") + algorithm_args.add_argument( + "--flow_threshold", default=0.4, type=float, help= + "flow error threshold, 0 turns off this optional QC step. 
Default: %(default)s") + algorithm_args.add_argument( + "--cellprob_threshold", default=0, type=float, + help="cellprob threshold, default is 0, decrease to find more and larger masks") + algorithm_args.add_argument( + "--niter", default=0, type=int, help= + "niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs" + ) + algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float, + help="anisotropy of volume in 3D") + algorithm_args.add_argument("--exclude_on_edges", action="store_true", + help="discard masks which touch edges of image") + algorithm_args.add_argument( + "--augment", action="store_true", + help="tiles image with overlapping tiles and flips overlapped regions to augment" + ) + algorithm_args.add_argument("--batch_size", default=8, type=int, + help="inference batch size. Default: %(default)s") + + # TODO: remove deprecated in future version + algorithm_args.add_argument( + "--no_resample", action="store_true", + help="disables flows/cellprob resampling to original image size before computing masks. 
Using this flag will make more masks more jagged with larger diameter settings.") + algorithm_args.add_argument( + "--no_interp", action="store_true", + help="do not interpolate when running dynamics (was default)") + + # output settings + output_args = parser.add_argument_group("Output Arguments") + output_args.add_argument( + "--save_png", action="store_true", + help="save masks as png") + output_args.add_argument( + "--save_tif", action="store_true", + help="save masks as tif") + output_args.add_argument( + "--output_name", default=None, type=str, + help="suffix for saved masks, default is _cp_masks, can be empty if `savedir` used and different of `dir`") + output_args.add_argument("--no_npy", action="store_true", + help="suppress saving of npy") + output_args.add_argument( + "--savedir", default=None, type=str, help= + "folder to which segmentation results will be saved (defaults to input image directory)" + ) + output_args.add_argument( + "--dir_above", action="store_true", help= + "save output folders adjacent to image folder instead of inside it (off by default)" + ) + output_args.add_argument("--in_folders", action="store_true", + help="flag to save output in folders (off by default)") + output_args.add_argument( + "--save_flows", action="store_true", help= + "whether or not to save RGB images of flows when masks are saved (disabled by default)" + ) + output_args.add_argument( + "--save_outlines", action="store_true", help= + "whether or not to save RGB outline images when masks are saved (disabled by default)" + ) + output_args.add_argument( + "--save_rois", action="store_true", + help="whether or not to save ImageJ compatible ROI archive (disabled by default)" + ) + output_args.add_argument( + "--save_txt", action="store_true", + help="flag to enable txt outlines for ImageJ (disabled by default)") + output_args.add_argument( + "--save_mpl", action="store_true", + help="save a figure of image/mask/flows using matplotlib (disabled by default). 
" + "This is slow, especially with large images.") + + # training settings + training_args = parser.add_argument_group("Training Arguments") + training_args.add_argument("--train", action="store_true", + help="train network using images in dir") + training_args.add_argument("--test_dir", default=[], type=str, + help="folder containing test data (optional)") + training_args.add_argument( + "--file_list", default=[], type=str, help= + "path to list of files for training and testing and probabilities for each image (optional)" + ) + training_args.add_argument( + "--mask_filter", default="_masks", type=str, help= + "end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s" + ) + training_args.add_argument("--learning_rate", default=1e-5, type=float, + help="learning rate. Default: %(default)s") + training_args.add_argument("--weight_decay", default=0.1, type=float, + help="weight decay. Default: %(default)s") + training_args.add_argument("--n_epochs", default=100, type=int, + help="number of epochs. Default: %(default)s") + training_args.add_argument("--train_batch_size", default=1, type=int, + help="training batch size. Default: %(default)s") + training_args.add_argument("--bsize", default=256, type=int, + help="block size for tiles. Default: %(default)s") + training_args.add_argument( + "--nimg_per_epoch", default=None, type=int, + help="number of train images per epoch. Default is to use all train images.") + training_args.add_argument( + "--nimg_test_per_epoch", default=None, type=int, + help="number of test images per epoch. Default is to use all test images.") + training_args.add_argument( + "--min_train_masks", default=5, type=int, help= + "minimum number of masks a training image must have to be used. Default: %(default)s" + ) + training_args.add_argument("--SGD", default=0, type=int, + help="Deprecated in v4.0.1+, not used - AdamW used instead. 
") + training_args.add_argument( + "--save_every", default=100, type=int, + help="number of epochs to skip between saves. Default: %(default)s") + training_args.add_argument( + "--save_each", action="store_true", + help="wether or not to save each epoch. Must also use --save_every. (default: False)") + training_args.add_argument( + "--model_name_out", default=None, type=str, + help="Name of model to save as, defaults to name describing model architecture. " + "Model is saved in the folder specified by --dir in models subfolder.") + + # TODO: remove deprecated in future version + training_args.add_argument( + "--diam_mean", default=30., type=float, help= + 'Deprecated in v4.0.1+, not used. ') + training_args.add_argument("--train_size", action="store_true", help= + 'Deprecated in v4.0.1+, not used. ') + + return parser diff --git a/models/seg_post_model/cellpose/core.py b/models/seg_post_model/cellpose/core.py new file mode 100644 index 0000000000000000000000000000000000000000..93ec31e02aa9e58a87edcabd585c00b466ac84d1 --- /dev/null +++ b/models/seg_post_model/cellpose/core.py @@ -0,0 +1,322 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu. +""" +import logging +import numpy as np +from tqdm import trange +from . import transforms, utils + +import torch + +TORCH_ENABLED = True + +core_logger = logging.getLogger(__name__) +tqdm_out = utils.TqdmToLogger(core_logger, level=logging.INFO) + + +def use_gpu(gpu_number=0, use_torch=True): + """ + Check if GPU is available for use. + + Args: + gpu_number (int): The index of the GPU to be used. Default is 0. + use_torch (bool): Whether to use PyTorch for GPU check. Default is True. + + Returns: + bool: True if GPU is available, False otherwise. + + Raises: + ValueError: If use_torch is False, as cellpose only runs with PyTorch now. 
+ """ + if use_torch: + return _use_gpu_torch(gpu_number) + else: + raise ValueError("cellpose only runs with PyTorch now") + + +def _use_gpu_torch(gpu_number=0): + """ + Checks if CUDA or MPS is available and working with PyTorch. + + Args: + gpu_number (int): The GPU device number to use (default is 0). + + Returns: + bool: True if CUDA or MPS is available and working, False otherwise. + """ + try: + device = torch.device("cuda:" + str(gpu_number)) + _ = torch.zeros((1,1)).to(device) + core_logger.info("** TORCH CUDA version installed and working. **") + return True + except: + pass + try: + device = torch.device('mps:' + str(gpu_number)) + _ = torch.zeros((1,1)).to(device) + core_logger.info('** TORCH MPS version installed and working. **') + return True + except: + core_logger.info('Neither TORCH CUDA nor MPS version not installed/working.') + return False + + +def assign_device(use_torch=True, gpu=False, device=0): + """ + Assigns the device (CPU or GPU or mps) to be used for computation. + + Args: + use_torch (bool, optional): Whether to use torch for GPU detection. Defaults to True. + gpu (bool, optional): Whether to use GPU for computation. Defaults to False. + device (int or str, optional): The device index or name to be used. Defaults to 0. 
+ + Returns: + torch.device, bool (True if GPU is used, False otherwise) + """ + + if isinstance(device, str): + if device != "mps" or not(gpu and torch.backends.mps.is_available()): + device = int(device) + if gpu and use_gpu(use_torch=True): + try: + if torch.cuda.is_available(): + device = torch.device(f'cuda:{device}') + core_logger.info(">>>> using GPU (CUDA)") + gpu = True + cpu = False + except: + gpu = False + cpu = True + try: + if torch.backends.mps.is_available(): + device = torch.device('mps') + core_logger.info(">>>> using GPU (MPS)") + gpu = True + cpu = False + except: + gpu = False + cpu = True + else: + device = torch.device('cpu') + core_logger.info('>>>> using CPU') + gpu = False + cpu = True + + if cpu: + device = torch.device("cpu") + core_logger.info(">>>> using CPU") + gpu = False + return device, gpu + + +def _to_device(x, device, dtype=torch.float32): + """ + Converts the input tensor or numpy array to the specified device. + + Args: + x (torch.Tensor or numpy.ndarray): The input tensor or numpy array. + device (torch.device): The target device. + + Returns: + torch.Tensor: The converted tensor on the specified device. + """ + if not isinstance(x, torch.Tensor): + X = torch.from_numpy(x).to(device, dtype=dtype) + return X + else: + return x + + +def _from_device(X): + """ + Converts a PyTorch tensor from the device to a NumPy array on the CPU. + + Args: + X (torch.Tensor): The input PyTorch tensor. + + Returns: + numpy.ndarray: The converted NumPy array. + """ + # The cast is so numpy conversion always works + x = X.detach().cpu().to(torch.float32).numpy() + return x + + +def _forward(net, x, feat=None): + """Converts images to torch tensors, runs the network model, and returns numpy arrays. + + Args: + net (torch.nn.Module): The network model. + x (numpy.ndarray): The input images. + + Returns: + Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and cellprob) and style features. 
+ """ + X = _to_device(x, device=net.device, dtype=net.dtype) + if feat is not None: + feat = _to_device(feat, device=net.device, dtype=net.dtype) + net.eval() + with torch.no_grad(): + y, style = net(X, feat=feat)[:2] + del X + y = _from_device(y) + style = _from_device(style) + return y, style + + +def run_net(net, imgi, feat=None, batch_size=8, augment=False, tile_overlap=0.1, bsize=224, + rsz=None): + """ + Run network on stack of images. + + (faster if augment is False) + + Args: + net (class): cellpose network (model.net) + imgi (np.ndarray): The input image or stack of images of size [Lz x Ly x Lx x nchan]. + batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8. + rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0. + augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False. + tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1. + bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224. + + Returns: + Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3]. + y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability. + style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles. 
+ """ + # run network + Lz, Ly0, Lx0, nchan = imgi.shape + if rsz is not None: + if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray): + rsz = [rsz, rsz] + Lyr, Lxr = int(Ly0 * rsz[0]), int(Lx0 * rsz[1]) + else: + Lyr, Lxr = Ly0, Lx0 # 512, 512 + + ly, lx = bsize, bsize # 256, 256 + ypad1, ypad2, xpad1, xpad2 = transforms.get_pad_yx(Lyr, Lxr, min_size=(bsize, bsize)) # 8 + Ly, Lx = Lyr + ypad1 + ypad2, Lxr + xpad1 + xpad2 # 528, 528 + pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]]) + + if augment: + ny = max(2, int(np.ceil(2. * Ly / bsize))) + nx = max(2, int(np.ceil(2. * Lx / bsize))) + else: + ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize)) # 3 + nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize)) # 3 + + + # run multiple slices at the same time + ntiles = ny * nx + nimgs = max(1, batch_size // ntiles) # number of imgs to run in the same batch, 1 + niter = int(np.ceil(Lz / nimgs)) # 1 + ziterator = (trange(niter, file=tqdm_out, mininterval=30) + if niter > 10 or Lz > 1 else range(niter)) + for k in ziterator: + inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs)) + IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32") # 9, 3, 256, 256 + if feat is not None: + FEATa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32") # 9, 256 + else: + FEATa = None + for i, b in enumerate(inds): + # pad image for net so Ly and Lx are divisible by 4 + imgb = transforms.resize_image(imgi[b], rsz=rsz) if rsz is not None else imgi[b].copy() + imgb = np.pad(imgb.transpose(2,0,1), pads, mode="constant") # 3, 528, 528 + + IMG, ysub, xsub, Lyt, Lxt = transforms.make_tiles( + imgb, bsize=bsize, augment=augment, + tile_overlap=tile_overlap) # IMG: 3, 3, 3, 256, 256 + IMGa[i * ntiles : (i+1) * ntiles] = np.reshape(IMG, + (ny * nx, nchan, ly, lx)) + if feat is not None: + featb = transforms.resize_image(feat[b], rsz=rsz) if rsz is not None else feat[b].copy() + featb = 
np.pad(featb.transpose(2,0,1), pads, mode="constant") + FEAT, ysub, xsub, Lyt, Lxt = transforms.make_tiles( + featb, bsize=bsize, augment=augment, + tile_overlap=tile_overlap) + FEATa[i * ntiles : (i+1) * ntiles] = np.reshape(FEAT, + (ny * nx, nchan, ly, lx)) + + # run network + for j in range(0, IMGa.shape[0], batch_size): + bslc = slice(j, min(j + batch_size, IMGa.shape[0])) + ya0, stylea0 = _forward(net, IMGa[bslc], feat=FEATa[bslc] if FEATa is not None else None) + if j == 0: + nout = ya0.shape[1] + ya = np.zeros((IMGa.shape[0], nout, ly, lx), "float32") + stylea = np.zeros((IMGa.shape[0], 256), "float32") + ya[bslc] = ya0 + stylea[bslc] = stylea0 + + # average tiles + for i, b in enumerate(inds): + if i==0 and k==0: + yf = np.zeros((Lz, nout, Ly, Lx), "float32") + styles = np.zeros((Lz, 256), "float32") + y = ya[i * ntiles : (i + 1) * ntiles] + if augment: + y = np.reshape(y, (ny, nx, 3, ly, lx)) + y = transforms.unaugment_tiles(y) + y = np.reshape(y, (-1, 3, ly, lx)) + yfi = transforms.average_tiles(y, ysub, xsub, Lyt, Lxt) + yf[b] = yfi[:, :imgb.shape[-2], :imgb.shape[-1]] + stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0) + stylei /= (stylei**2).sum()**0.5 + styles[b] = stylei + # slices from padding + yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2] + yf = yf.transpose(0,2,3,1) + return yf, np.array(styles) + + +def run_3D(net, imgs, batch_size=8, augment=False, + tile_overlap=0.1, bsize=224, net_ortho=None, + progress=None): + """ + Run network on image z-stack. + + (faster if augment is False) + + Args: + imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan]. + batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8. + rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0. + anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None. 
+ augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False. + tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1. + bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224. + net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes. Defaults to None. + progress (QProgressBar, optional): pyqt progress bar. Defaults to None. + + Returns: + Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3]. + y[...,0] is Z flow; y[...,1] is Y flow; y[...,2] is X flow; y[...,3] is cell probability. + style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles. + """ + sstr = ["YX", "ZY", "ZX"] + pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)] + ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)] + cp = [(1, 2), (0, 2), (0, 1)] + cpy = [(0, 1), (0, 1), (0, 1)] + shape = imgs.shape[:-1] + yf = np.zeros((*shape, 4), "float32") + for p in range(3): + xsl = imgs.transpose(pm[p]) + # per image + core_logger.info("running %s: %d planes of size (%d, %d)" % + (sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]])) + y, style = run_net(net, + xsl, batch_size=batch_size, augment=augment, + bsize=bsize, tile_overlap=tile_overlap, + rsz=None) + yf[..., -1] += y[..., -1].transpose(ipm[p]) + for j in range(2): + yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p]) + y = None; del y + + if progress is not None: + progress.setValue(25 + 15 * p) + + return yf, style diff --git a/models/seg_post_model/cellpose/denoise.py b/models/seg_post_model/cellpose/denoise.py new file mode 100644 index 0000000000000000000000000000000000000000..1b1c57fb1204f287db90995e096f4d0f5950b958 --- /dev/null +++ b/models/seg_post_model/cellpose/denoise.py @@ -0,0 +1,1474 @@ +""" +Copyright © 2025 Howard Hughes Medical 
# Names of the pretrained restoration models: one per noise type and cell type,
# plus per/seg/rec loss variants and anisotropy models for the non-cyto3 types.
MODEL_NAMES = []
for ctype in ["cyto3", "cyto2", "nuclei"]:
    for ntype in ["denoise", "deblur", "upsample", "oneclick"]:
        MODEL_NAMES.append(f"{ntype}_{ctype}")
        if ctype != "cyto3":
            for ltype in ["per", "seg", "rec"]:
                MODEL_NAMES.append(f"{ntype}_{ltype}_{ctype}")
    if ctype != "cyto3":
        MODEL_NAMES.append(f"aniso_{ctype}")

# shared loss criteria: MSE for flows/reconstruction, BCE-with-logits for cellprob
criterion = nn.MSELoss(reduction="mean")
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")


def deterministic(seed=0):
    """Seed all RNGs (torch, CUDA, numpy, random) and make cuDNN deterministic.

    Intended for creating reproducible test data.

    Args:
        seed (int, optional): Seed value for all generators. Defaults to 0.
    """
    import random
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    # (removed a redundant duplicate torch.manual_seed(seed) call)
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True


def loss_fn_rec(lbl, y):
    """Reconstruction loss: scaled MSE between true image `lbl` and prediction `y`.

    Args:
        lbl (torch.Tensor): Target (clean) image.
        y (torch.Tensor): Predicted image.

    Returns:
        torch.Tensor: Scalar loss (80 * mean squared error).
    """
    loss = 80. * criterion(y, lbl)
    return loss


def loss_fn_seg(lbl, y):
    """Segmentation loss between ground-truth labels `lbl` and prediction `y`.

    Args:
        lbl (torch.Tensor): Labels of shape (B, 3, H, W); channel 0 is cellprob,
            channels 1: are the flow targets (scaled by 5 below).
        y (torch.Tensor): Network output of shape (B, 3, H, W); channels 0-1 are
            flows, channel 2 is the cellprob logit.

    Returns:
        torch.Tensor: Scalar loss = MSE(flows)/2 + BCEWithLogits(cellprob).
    """
    veci = 5. * lbl[:, 1:]
    lbl = (lbl[:, 0] > .5).float()
    loss = criterion(y[:, :2], veci)
    loss /= 2.
    loss2 = criterion2(y[:, 2], lbl)
    loss = loss + loss2
    return loss
+ + Args: + Tdown (list): List of tensors output by each downsampling block of network. + + Returns: + list: List of correlations for each input tensor. + """ + Tnorm = [x - x.mean((-2, -1), keepdim=True) for x in Tdown] + Tnorm = [x / x.std((-2, -1), keepdim=True) for x in Tnorm] + Sigma = [ + torch.einsum("bnxy, bmxy -> bnm", x, x) / (x.shape[-2] * x.shape[-1]) + for x in Tnorm + ] + return Sigma + + +def imstats(X, net1): + """ + Calculates the image correlation matrices for the perceptual loss. + + Args: + X (torch.Tensor): Input image tensor. + net1: Cellpose net. + + Returns: + list: A list of tensors of correlation matrices. + """ + _, _, Tdown = net1(X) + Sigma = get_sigma(Tdown) + Sigma = [x.detach() for x in Sigma] + return Sigma + + +def loss_fn_per(img, net1, yl): + """ + Calculates the perceptual loss function for image restoration. + + Args: + img (torch.Tensor): Input image tensor (noisy/blurry/downsampled). + net1 (torch.nn.Module): Perceptual loss net (Cellpose segmentation net). + yl (torch.Tensor): Clean image tensor. + + Returns: + torch.Tensor: Mean perceptual loss. + """ + Sigma = imstats(img, net1) + sd = [x.std((1, 2)) + 1e-6 for x in Sigma] + Sigma_test = get_sigma(yl) + losses = torch.zeros(len(Sigma[0]), device=img.device) + for k in range(len(Sigma)): + losses = losses + (((Sigma_test[k] - Sigma[k])**2).mean((1, 2)) / sd[k]**2) + return losses.mean() + + +def test_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]): + """ + Calculates the test loss for image restoration tasks. + + Args: + net0 (torch.nn.Module): The image restoration network. + X (torch.Tensor): The input image tensor. + net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None. + img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None. + lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None. 
+ lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.]. + + Returns: + tuple: A tuple containing the total loss and the perceptual loss. + """ + net0.eval() + if net1 is not None: + net1.eval() + loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device) + + with torch.no_grad(): + img_dn = net0(X)[0] + if lam[2] > 0.: + loss += lam[2] * loss_fn_rec(img, img_dn) + if lam[1] > 0. or lam[0] > 0.: + y, _, ydown = net1(img_dn) + if lam[1] > 0.: + loss += lam[1] * loss_fn_seg(lbl, y) + if lam[0] > 0.: + loss_per = loss_fn_per(img, net1, ydown) + loss += lam[0] * loss_per + return loss, loss_per + + +def train_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]): + """ + Calculates the train loss for image restoration tasks. + + Args: + net0 (torch.nn.Module): The image restoration network. + X (torch.Tensor): The input image tensor. + net1 (torch.nn.Module, optional): The segmentation network for segmentation or perceptual loss. Defaults to None. + img (torch.Tensor, optional): Clean image tensor for perceptual or reconstruction loss. Defaults to None. + lbl (torch.Tensor, optional): The ground truth flows/cellprob tensor for segmentation loss. Defaults to None. + lam (list, optional): The weights for different loss components (perceptual, segmentation, reconstruction). Defaults to [1., 1.5, 0.]. + + Returns: + tuple: A tuple containing the total loss and the perceptual loss. + """ + net0.train() + if net1 is not None: + net1.eval() + loss, loss_per = torch.zeros(1, device=X.device), torch.zeros(1, device=X.device) + + img_dn = net0(X)[0] + if lam[2] > 0.: + loss += lam[2] * loss_fn_rec(img, img_dn) + if lam[1] > 0. 
def img_norm(imgi):
    """Percentile-normalize every (image, channel) plane of `imgi` in place.

    Each plane is rescaled so its 1st percentile maps to 0 and its 99th
    percentile maps to 1. Planes whose percentile range is at most 1e-3
    (near-constant planes) are left untouched.

    Args:
        imgi (torch.Tensor): Input of shape (B, C, ...spatial...). Modified in place.

    Returns:
        torch.Tensor: The normalized tensor, reshaped to the input shape.
    """
    orig_shape = imgi.shape
    flat = imgi.reshape(orig_shape[0], orig_shape[1], -1)
    # per-plane 1st and 99th percentiles, shape (2, B, C, 1)
    qs = torch.quantile(flat, torch.tensor([0.01, 0.99], device=imgi.device),
                        dim=-1, keepdim=True)
    lo, hi = qs[0], qs[1]
    for c in range(flat.shape[1]):
        # skip near-constant planes to avoid dividing by ~0
        valid = (hi[:, c, 0] - lo[:, c, 0]) > 1e-3
        flat[valid, c] -= lo[valid, c]
        flat[valid, c] /= hi[valid, c] - lo[valid, c]
    return flat.reshape(orig_shape)
def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7,
              ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None,
              ds=None, uniform_blur=False, partial_blur=False):
    """Add synthetic degradation (blur, downsampling, poisson noise) to images.

    Args:
        lbl (torch.Tensor): Input images of shape (nimg, nchan, Ly, Lx).
        alpha (float, optional): Shape of the gamma distribution for poisson scale. Defaults to 4.
        beta (float, optional): Rate of the gamma distribution for poisson scale. Defaults to 0.7.
        poisson (float, optional): Probability of adding poisson noise. Defaults to 0.7.
        blur (float, optional): Probability of adding gaussian blur. Defaults to 0.7.
        gblur (float, optional): Scale factor for the gaussian blur. Defaults to 1.0.
        downsample (float, optional): Probability of downsampling. Defaults to 0.7.
        ds_max (int, optional): Maximum downsampling factor. Defaults to 7.
        diams (torch.Tensor, optional): Object diameters per image. Defaults to None (30).
        pscale (torch.Tensor, optional): Fixed poisson scale instead of sampling. Defaults to None.
        iso (bool, optional): Isotropic blur/downsample. Defaults to True.
        sigma0 (torch.Tensor, optional): Fixed Y-axis gaussian std instead of sampling.
        sigma1 (torch.Tensor, optional): Fixed X-axis gaussian std instead of sampling.
        ds (torch.Tensor, optional): Fixed downsampling factor instead of sampling.
        uniform_blur (bool, optional): Sample blur uniformly rather than exponentially.
        partial_blur (bool, optional): Blend blur over only part of the image width.

    Returns:
        torch.Tensor: Degraded, percentile-normalized images, same shape as input.
    """
    device = lbl.device
    imgi = torch.zeros_like(lbl)
    Ly, Lx = lbl.shape[-2:]

    diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device)
    #ds0 = 1 if ds is None else ds.item()
    # broadcast a user-specified downsampling factor to every image
    ds = ds * torch.ones(
        (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds

    # choose which images get downsampled; ii = indices with factor > 1
    ii = []
    idownsample = np.random.rand(len(lbl)) < downsample
    if (ds is None and idownsample.sum() > 0.) or not iso:
        ds = torch.ones(len(lbl), dtype=torch.long, device=device)
        ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),),
                                        device=device)
        ii = torch.nonzero(ds > 1).flatten()
    elif ds is not None and (ds > 1).sum():
        ii = torch.nonzero(ds > 1).flatten()

    # add gaussian blur; images that will be downsampled are always pre-blurred
    iblur = torch.rand(len(lbl), device=device) < blur
    iblur[ii] = True
    if iblur.sum() > 0:
        if sigma0 is None:
            if uniform_blur and iso:
                xr = torch.rand(len(lbl), device=device)
                if len(ii) > 0:
                    # tie blur strength to the downsampling factor
                    xr[ii] = ds[ii].float() / 2. / gblur
                sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur])
                sigma1 = sigma0.clone()
            elif not iso:
                # anisotropic: strong blur in Y, weak (sigma/10) in X
                xr = torch.rand(len(lbl), device=device)
                if len(ii) > 0:
                    xr[ii] = (ds[ii].float()) / gblur
                    xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35
                    xr[ii] = torch.clip(xr[ii], 0.05, 1.5)
                sigma0 = diams[iblur] / 30. * gblur * xr[iblur]
                sigma1 = sigma0.clone() / 10.
            else:
                # default: exponentially distributed blur strengths
                xrand = np.random.exponential(1, size=iblur.sum())
                xrand = np.clip(xrand * 0.5, 0.1, 1.0)
                xrand *= gblur
                sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to(
                    device)
                sigma1 = sigma0.clone()
        else:
            sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device)
            sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device)

        # build per-image separable gaussian filters; xr here is the filter half-width
        xr = max(8, sigma0.max().long() * 2)
        gfilt0 = torch.exp(-torch.arange(-xr + 1, xr, device=device)**2 /
                           (2 * sigma0.unsqueeze(-1)**2))
        gfilt0 /= gfilt0.sum(axis=-1, keepdims=True)
        gfilt1 = torch.zeros_like(gfilt0)
        gfilt1[sigma1 == sigma0] = gfilt0[sigma1 == sigma0]
        gfilt1[sigma1 != sigma0] = torch.exp(
            -torch.arange(-xr + 1, xr, device=device)**2 /
            (2 * sigma1[sigma1 != sigma0].unsqueeze(-1)**2))
        # sigma1 == 0 means no blur along X: delta filter centered at xr
        gfilt1[sigma1 == 0] = 0.
        gfilt1[sigma1 == 0, xr] = 1.
        gfilt1 /= gfilt1.sum(axis=-1, keepdims=True)
        gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1)
        gfilt /= gfilt.sum(axis=(1, 2), keepdims=True)

        # grouped conv applies each image's own filter (channels become groups)
        lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1),
                          padding=gfilt.shape[-1] // 2,
                          groups=gfilt.shape[0]).transpose(1, 0)
        if partial_blur:
            #yc, xc = np.random.randint(100, Ly-100), np.random.randint(100, Lx-100)
            imgi[iblur] = lbl[iblur].clone()
            Lxc = int(Lx * 0.85)
            ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32),
                                    torch.arange(0, Lxc, dtype=torch.float32),
                                    indexing="ij")
            # NOTE(review): `/ 2*(0.001**2)` parses as (x/2)*1e-6, not x/(2*sigma^2);
            # looks like a precedence slip -- confirm the intended falloff.
            mask = torch.exp(-(ym**2 + xm**2) / 2*(0.001**2))
            mask -= mask.min()
            mask /= mask.max()
            lbl_blur_crop = lbl_blur[:, :, :, :Lxc]
            imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask +
                                       (1-mask) * imgi[iblur, :, :, :Lxc])
        else:
            imgi[iblur] = lbl_blur

    imgi[~iblur] = lbl[~iblur]

    # apply downsample then bilinearly upsample back to the original size
    for k in ii:
        i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]]
        imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear")

    # add poisson noise
    ipoisson = np.random.rand(len(lbl)) < poisson
    if ipoisson.sum() > 0:
        if pscale is None:
            pscale = torch.zeros(len(lbl))  # dead assignment, overwritten below
            m = torch.distributions.gamma.Gamma(alpha, beta)
            pscale = torch.clamp(m.rsample(sample_shape=(ipoisson.sum(),)), 1.)
            #pscale = torch.clamp(20 * (torch.rand(size=(len(lbl),), device=lbl.device)), 1.5)
            pscale = pscale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device)
        else:
            pscale = pscale * torch.ones((ipoisson.sum(), 1, 1, 1), device=device)
        imgi[ipoisson] = torch.poisson(pscale * imgi[ipoisson])
        imgi[~ipoisson] = imgi[~ipoisson]  # no-op; kept for parity with the original

    # renormalize
    imgi = img_norm(imgi)

    return imgi
+ + Returns: + torch.Tensor: The augmented image and augmented noisy/blurry/downsampled version of image. + torch.Tensor: The augmented labels. + float: The scale factor applied to the image. + """ + if device == None: + device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None + + diams = 30 if diams is None else diams + random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1)) + random_rsc = diams / random_diam #/ random_diam + #rsc /= random_scale + xy0 = (340, 340) + nchan = data[0].shape[0] + data_new = np.zeros((len(data), (1 + keep_raw) * nchan, xy0[0], xy0[1]), "float32") + labels_new = np.zeros((len(data), 3, xy0[0], xy0[1]), "float32") + for i in range( + len(data)): #, (sc, img, lbl) in enumerate(zip(random_rsc, data, labels)): + sc = random_rsc[i] + img = data[i] + lbl = labels[i] if labels is not None else None + # create affine transform to resize + Ly, Lx = img.shape[-2:] + dxy = np.maximum(0, np.array([Lx / sc - xy0[1], Ly / sc - xy0[0]])) + dxy = (np.random.rand(2,) - .5) * dxy + cc = np.array([Lx / 2, Ly / 2]) + cc1 = cc - np.array([Lx - xy0[1], Ly - xy0[0]]) / 2 + dxy + pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])]) + pts2 = np.float32( + [cc1, cc1 + np.array([1, 0]) / sc, cc1 + np.array([0, 1]) / sc]) + M = cv2.getAffineTransform(pts1, pts2) + + # apply to image + for c in range(nchan): + img_rsz = cv2.warpAffine(img[c], M, xy0, flags=cv2.INTER_LINEAR) + #img_noise = add_noise(torch.from_numpy(img_rsz).to(device).unsqueeze(0)).cpu().numpy().squeeze(0) + data_new[i, c] = img_rsz + if keep_raw: + data_new[i, c + nchan] = img_rsz + + if lbl is not None: + # apply to labels + labels_new[i, 0] = cv2.warpAffine(lbl[0], M, xy0, flags=cv2.INTER_NEAREST) + labels_new[i, 1] = cv2.warpAffine(lbl[1], M, xy0, flags=cv2.INTER_LINEAR) + labels_new[i, 2] = cv2.warpAffine(lbl[2], M, xy0, flags=cv2.INTER_LINEAR) + + rsc = random_diam / diam_mean + + # add 
noise before augmentations + img = torch.from_numpy(data_new).to(device) + img = torch.clamp(img, 0.) + # just add noise to cyto if nchan_noise=1 + img[:, :nchan_noise] = add_noise( + img[:, :nchan_noise], poisson=poisson, blur=blur, ds_max=ds_max, iso=iso, + downsample=downsample, beta=beta, gblur=gblur, + diams=torch.from_numpy(random_diam).to(device).float()) + # img -= img.mean(dim=(-2,-1), keepdim=True) + # img /= img.std(dim=(-2,-1), keepdim=True) + 1e-3 + img = img.cpu().numpy() + + # augmentations + img, lbl, scale = transforms.random_rotate_and_resize( + img, + Y=labels_new, + xy=xy, + rotate=False if not iso else rotate, + #(iso and downsample==0), + rescale=rsc, + scale_range=0.5) + img = torch.from_numpy(img).to(device) + lbl = torch.from_numpy(lbl).to(device) + + return img, lbl, scale + + +def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None): + """ + Creates a Cellpose network with a single input channel. + + Args: + device (str): The device to run the network on. + model_type (str, optional): The type of Cellpose model to use. Defaults to "cyto2". + pretrained_model (str, optional): The path to a pretrained model file. Defaults to None. + + Returns: + torch.nn.Module: The Cellpose network with a single input channel. 
+ """ + if pretrained_model is not None and not os.path.exists(pretrained_model): + model_type = pretrained_model + pretrained_model = None + nbase = [32, 64, 128, 256] + nchan = 1 + net1 = resnet_torch.CPnet([nchan, *nbase], nout=3, sz=3).to(device) + filename = model_path(model_type, + 0) if pretrained_model is None else pretrained_model + weights = torch.load(filename, weights_only=True) + zp = 0 + print(filename) + for name in net1.state_dict(): + if ("res_down_0.conv.conv_0" not in name and + #"output" not in name and + "res_down_0.proj" not in name and name != "diam_mean" and + name != "diam_labels"): + net1.state_dict()[name].copy_(weights[name]) + elif "res_down_0" in name: + if len(weights[name].shape) > 0: + new_weight = torch.zeros_like(net1.state_dict()[name]) + if weights[name].shape[0] == 2: + new_weight[:] = weights[name][0] + elif len(weights[name].shape) > 1 and weights[name].shape[1] == 2: + new_weight[:, zp] = weights[name][:, 0] + else: + new_weight = weights[name] + else: + new_weight = weights[name] + net1.state_dict()[name].copy_(new_weight) + return net1 + + +class CellposeDenoiseModel(): + """ model to run Cellpose and Image restoration """ + + def __init__(self, gpu=False, pretrained_model=False, model_type=None, + restore_type="denoise_cyto3", nchan=2, + chan2_restore=False, device=None): + + self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore, + device=device) + self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan, + pretrained_model=pretrained_model, device=device) + + def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, + normalize=True, rescale=None, diameter=None, tile_overlap=0.1, + augment=False, resample=True, invert=False, flow_threshold=0.4, + cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0, + min_size=15, niter=None, interp=True, bsize=224, flow3D_smooth=0): + """ + Restore array or list of images using the image restoration model, and 
then segment. + + Args: + x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images + batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU + (can make smaller or bigger depending on GPU memory usage). Defaults to 8. + channels (list, optional): list of channels, either of length 2 or of length number of images by 2. + First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). + Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). + For instance, to segment grayscale images, input [0,0]. To segment images with cells + in green and nuclei in blue, input [2,3]. To segment one grayscale image and one + image with cells in green and nuclei in blue, input [[0,0], [2,3]]. + Defaults to None. + channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x. + if None, channels dimension is attempted to be automatically determined. Defaults to None. + z_axis (int, optional): z axis in element of list x, or of np.ndarray x. + if None, z dimension is attempted to be automatically determined. Defaults to None. + normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel; + can also pass dictionary of parameters (all keys are optional, default values shown): + - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored) + - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels + - "normalize"=True ; run normalization (if False, all following parameters ignored) + - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high] + - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 
100) + - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode. + Defaults to True. + rescale (float, optional): resize factor for each image, if None, set to 1.0; + (only used if diameter is None). Defaults to None. + diameter (float, optional): diameter for each image, + if diameter is None, set to diam_mean or diam_train if available. Defaults to None. + tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1. + augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False. + resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True. + invert (bool, optional): invert image pixel intensity before running network. Defaults to False. + flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4. + cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0. + do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False. + anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None. + stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0. + min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15. + flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0. + niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None. 
+ interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True. + + Returns: + A tuple containing (masks, flows, styles, imgs); masks: labelled image(s), where 0=no masks; 1,2,...=mask labels; + flows: list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration; + styles: style vector summarizing each image of size 256; + imgs: Restored images. + """ + + if isinstance(normalize, dict): + normalize_params = {**normalize_default, **normalize} + elif not isinstance(normalize, bool): + raise ValueError("normalize parameter must be a bool or a dict") + else: + normalize_params = normalize_default + normalize_params["normalize"] = normalize + normalize_params["invert"] = invert + + img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels, + channel_axis=channel_axis, z_axis=z_axis, + do_3D=do_3D, + normalize=normalize_params, rescale=rescale, + diameter=diameter, + tile_overlap=tile_overlap, bsize=bsize) + + # turn off special normalization for segmentation + normalize_params = normalize_default + + # change channels for segmentation + if channels is not None: + channels_new = [0, 0] if channels[0] == 0 else [1, 2] + else: + channels_new = None + # change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean) + diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter + masks, flows, styles = self.cp.eval( + img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1, + z_axis=0 if not isinstance(img_restore, list) and img_restore.ndim > 3 and img_restore.shape[0] > 0 else None, + normalize=normalize_params, rescale=rescale, diameter=diameter, + tile_overlap=tile_overlap, augment=augment, resample=resample, + invert=invert, flow_threshold=flow_threshold, + cellprob_threshold=cellprob_threshold, do_3D=do_3D, 
anisotropy=anisotropy, + stitch_threshold=stitch_threshold, min_size=min_size, niter=niter, + interp=interp, bsize=bsize) + + return masks, flows, styles, img_restore + + +class DenoiseModel(): + """ + DenoiseModel class for denoising images using Cellpose denoising model. + + Args: + gpu (bool, optional): Whether to use GPU for computation. Defaults to False. + pretrained_model (bool or str or Path, optional): Pretrained model to use for denoising. + Can be a string or path. Defaults to False. + nchan (int, optional): Number of channels in the input images, all Cellpose 3 models were trained with nchan=1. Defaults to 1. + model_type (str, optional): Type of pretrained model to use ("denoise_cyto3", "deblur_cyto3", "upsample_cyto3", ...). Defaults to None. + chan2 (bool, optional): Whether to use a separate model for the second channel. Defaults to False. + diam_mean (float, optional): Mean diameter of the objects in the images. Defaults to 30.0. + device (torch.device, optional): Device to use for computation. Defaults to None. + + Attributes: + nchan (int): Number of channels in the input images. + diam_mean (float): Mean diameter of the objects in the images. + net (CPnet): Cellpose network for denoising. + pretrained_model (bool or str or Path): Pretrained model path to use for denoising. + net_chan2 (CPnet or None): Cellpose network for the second channel, if applicable. + net_type (str): Type of the denoising network. + + Methods: + eval(x, batch_size=8, channels=None, channel_axis=None, z_axis=None, + normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1) + Denoise array or list of images using the denoising model. + + _eval(net, x, normalize=True, rescale=None, diameter=None, tile=True, + tile_overlap=0.1) + Run denoising model on a single channel. 
+ """ + + def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None, + chan2=False, diam_mean=30., device=None): + self.nchan = nchan + if pretrained_model and (not isinstance(pretrained_model, str) and + not isinstance(pretrained_model, Path)): + raise ValueError("pretrained_model must be a string or path") + + self.diam_mean = diam_mean + builtin = True + if model_type is not None or (pretrained_model and + not os.path.exists(pretrained_model)): + pretrained_model_string = model_type if model_type is not None else "denoise_cyto3" + if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]): + pretrained_model_string = "denoise_cyto3" + pretrained_model = model_path(pretrained_model_string) + if (pretrained_model and not os.path.exists(pretrained_model)): + denoise_logger.warning("pretrained model has incorrect path") + denoise_logger.info(f">> {pretrained_model_string} << model set to be used") + self.diam_mean = 17. if "nuclei" in pretrained_model_string else 30. 
+ else: + if pretrained_model: + builtin = False + pretrained_model_string = pretrained_model + denoise_logger.info(f">>>> loading model {pretrained_model_string}") + + # assign network device + if device is None: + sdevice, gpu = assign_device(use_torch=True, gpu=gpu) + self.device = device if device is not None else sdevice + if device is not None: + device_gpu = self.device.type == "cuda" + self.gpu = gpu if device is None else device_gpu + + # create network + self.nchan = nchan + self.nclasses = 1 + nbase = [32, 64, 128, 256] + self.nchan = nchan + self.nbase = [nchan, *nbase] + + self.net = CPnet(self.nbase, self.nclasses, sz=3, + max_pool=True, diam_mean=diam_mean).to(self.device) + + self.pretrained_model = pretrained_model + self.net_chan2 = None + if self.pretrained_model: + self.net.load_model(self.pretrained_model, device=self.device) + denoise_logger.info( + f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)" + ) + if chan2 and builtin: + chan2_path = model_path( + os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei") + print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}") + self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3, + max_pool=True, + diam_mean=17.).to(self.device) + self.net_chan2.load_model(chan2_path, device=self.device) + self.net_type = "cellpose_denoise" + + def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None, + normalize=True, rescale=None, diameter=None, tile=True, do_3D=False, + tile_overlap=0.1, bsize=224): + """ + Restore array or list of images using the image restoration model. + + Args: + x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images + batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU + (can make smaller or bigger depending on GPU memory usage). Defaults to 8. 
+ channels (list, optional): list of channels, either of length 2 or of length number of images by 2. + First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue). + Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue). + For instance, to segment grayscale images, input [0,0]. To segment images with cells + in green and nuclei in blue, input [2,3]. To segment one grayscale image and one + image with cells in green and nuclei in blue, input [[0,0], [2,3]]. + Defaults to None. + channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x. + if None, channels dimension is attempted to be automatically determined. Defaults to None. + z_axis (int, optional): z axis in element of list x, or of np.ndarray x. + if None, z dimension is attempted to be automatically determined. Defaults to None. + normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel; + can also pass dictionary of parameters (all keys are optional, default values shown): + - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored) + - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels + - "normalize"=True ; run normalization (if False, all following parameters ignored) + - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high] + - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100) + - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode. + Defaults to True. + rescale (float, optional): resize factor for each image, if None, set to 1.0; + (only used if diameter is None). Defaults to None. 
+ diameter (float, optional): diameter for each image, + if diameter is None, set to diam_mean or diam_train if available. Defaults to None. + tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1. + + Returns: + list: A list of 2D/3D arrays of restored images + + """ + if isinstance(x, list) or x.squeeze().ndim == 5: + tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO) + nimg = len(x) + iterator = trange(nimg, file=tqdm_out, + mininterval=30) if nimg > 1 else range(nimg) + imgs = [] + for i in iterator: + imgi = self.eval( + x[i], batch_size=batch_size, + channels=channels[i] if channels is not None and + ((len(channels) == len(x) and + (isinstance(channels[i], list) or + isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2)) + else channels, channel_axis=channel_axis, z_axis=z_axis, + normalize=normalize, + do_3D=do_3D, + rescale=rescale[i] if isinstance(rescale, list) or + isinstance(rescale, np.ndarray) else rescale, + diameter=diameter[i] if isinstance(diameter, list) or + isinstance(diameter, np.ndarray) else diameter, + tile_overlap=tile_overlap, bsize=bsize) + imgs.append(imgi) + if isinstance(x, np.ndarray): + imgs = np.array(imgs) + return imgs + + else: + # reshape image + x = transforms.convert_image(x, channels, channel_axis=channel_axis, + z_axis=z_axis, do_3D=do_3D, nchan=None) + if x.ndim < 4: + squeeze = True + x = x[np.newaxis, ...] + else: + squeeze = False + + # may need to interpolate image before running upsampling + self.ratio = 1. 
+ if "upsample" in self.pretrained_model: + Ly, Lx = x.shape[-3:-1] + if diameter is not None and 3 <= diameter < self.diam_mean: + self.ratio = self.diam_mean / diameter + denoise_logger.info( + f"upsampling image to {self.diam_mean} pixel diameter ({self.ratio:0.2f} times)" + ) + Lyr, Lxr = int(Ly * self.ratio), int(Lx * self.ratio) + x = transforms.resize_image(x, Ly=Lyr, Lx=Lxr) + else: + denoise_logger.warning( + f"not interpolating image before upsampling because diameter is set >= {self.diam_mean}" + ) + #raise ValueError(f"diameter is set to {diameter}, needs to be >=3 and < {self.dn.diam_mean}") + + self.batch_size = batch_size + + if diameter is not None and diameter > 0: + rescale = self.diam_mean / diameter + elif rescale is None: + rescale = 1.0 + + if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0): + x = x[..., :1] + + for c in range(x.shape[-1]): + rescale0 = rescale * 30. / 17. if c == 1 else rescale + if c == 0 or self.net_chan2 is None: + x[..., + c] = self._eval(self.net, x[..., c:c + 1], batch_size=batch_size, + normalize=normalize, rescale=rescale0, + tile_overlap=tile_overlap, bsize=bsize)[...,0] + else: + x[..., + c] = self._eval(self.net_chan2, x[..., + c:c + 1], batch_size=batch_size, + normalize=normalize, rescale=rescale0, + tile_overlap=tile_overlap, bsize=bsize)[...,0] + x = x[0] if squeeze else x + return x + + def _eval(self, net, x, batch_size=8, normalize=True, rescale=None, + tile_overlap=0.1, bsize=224): + """ + Run image restoration model on a single channel. + + Args: + x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images + batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU + (can make smaller or bigger depending on GPU memory usage). Defaults to 8. 
+ normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel; + can also pass dictionary of parameters (all keys are optional, default values shown): + - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored) + - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels + - "normalize"=True ; run normalization (if False, all following parameters ignored) + - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high] + - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100) + - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode. + Defaults to True. + rescale (float, optional): resize factor for each image, if None, set to 1.0; + (only used if diameter is None). Defaults to None. + tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1. 
+ + Returns: + list: A list of 2D/3D arrays of restored images + + """ + if isinstance(normalize, dict): + normalize_params = {**normalize_default, **normalize} + elif not isinstance(normalize, bool): + raise ValueError("normalize parameter must be a bool or a dict") + else: + normalize_params = normalize_default + normalize_params["normalize"] = normalize + + tic = time.time() + shape = x.shape + nimg = shape[0] + + do_normalization = True if normalize_params["normalize"] else False + + img = np.asarray(x) + if do_normalization: + img = transforms.normalize_img(img, **normalize_params) + if rescale != 1.0: + img = transforms.resize_image(img, rsz=rescale) + yf, style = run_net(self.net, img, bsize=bsize, + tile_overlap=tile_overlap) + yf = transforms.resize_image(yf, shape[1], shape[2]) + imgs = yf + del yf, style + + # imgs = np.zeros((*x.shape[:-1], 1), np.float32) + # for i in iterator: + # img = np.asarray(x[i]) + # if do_normalization: + # img = transforms.normalize_img(img, **normalize_params) + # if rescale != 1.0: + # img = transforms.resize_image(img, rsz=[rescale, rescale]) + # if img.ndim == 2: + # img = img[:, :, np.newaxis] + # yf, style = run_net(net, img, batch_size=batch_size, augment=False, + # tile=tile, tile_overlap=tile_overlap, bsize=bsize) + # img = transforms.resize_image(yf, Ly=x.shape[-3], Lx=x.shape[-2]) + + # if img.ndim == 2: + # img = img[:, :, np.newaxis] + # imgs[i] = img + # del yf, style + net_time = time.time() - tic + if nimg > 1: + denoise_logger.info("imgs denoised in %2.2fs" % (net_time)) + + return imgs + + +def train(net, train_data=None, train_labels=None, train_files=None, test_data=None, + test_labels=None, test_files=None, train_probs=None, test_probs=None, + lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None, + save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0, + iso=True, uniform_blur=False, downsample=0., ds_max=7, + learning_rate=0.005, n_epochs=500, + 
weight_decay=0.00001, batch_size=8, nimg_per_epoch=None, + nimg_test_per_epoch=None, model_name=None): + + # net properties + device = net.device + nchan = net.nchan + diam_mean = net.diam_mean.item() + + args = np.array([poisson, beta, blur, gblur, downsample]) + if args.ndim == 1: + args = args[:, np.newaxis] + poisson, beta, blur, gblur, downsample = args + nnoise = len(poisson) + + d = datetime.datetime.now() + if save_path is not None: + if model_name is None: + filename = "" + lstrs = ["per", "seg", "rec"] + for k, (l, s) in enumerate(zip(lam, lstrs)): + filename += f"{s}_{l:.2f}_" + if not iso: + filename += "aniso_" + if poisson.sum() > 0: + filename += "poisson_" + if blur.sum() > 0: + filename += "blur_" + if downsample.sum() > 0: + filename += "downsample_" + filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f") + filename = os.path.join(save_path, filename) + else: + filename = os.path.join(save_path, model_name) + print(filename) + for i in range(len(poisson)): + denoise_logger.info( + f"poisson: {poisson[i]: 0.2f}, beta: {beta[i]: 0.2f}, blur: {blur[i]: 0.2f}, gblur: {gblur[i]: 0.2f}, downsample: {downsample[i]: 0.2f}" + ) + net1 = one_chan_cellpose(device=device, pretrained_model=seg_model_type) + + learning_rate_const = learning_rate + LR = np.linspace(0, learning_rate_const, 10) + LR = np.append(LR, learning_rate_const * np.ones(n_epochs - 100)) + for i in range(10): + LR = np.append(LR, LR[-1] / 2 * np.ones(10)) + learning_rate = LR + + batch_size = 8 + optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate[0], + weight_decay=weight_decay) + if train_data is not None: + nimg = len(train_data) + diam_train = np.array( + [utils.diameters(train_labels[k])[0] for k in trange(len(train_labels))]) + diam_train[diam_train < 5] = 5. + if test_data is not None: + diam_test = np.array( + [utils.diameters(test_labels[k])[0] for k in trange(len(test_labels))]) + diam_test[diam_test < 5] = 5. 
+ nimg_test = len(test_data) + else: + nimg = len(train_files) + denoise_logger.info(">>> using files instead of loading dataset") + train_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in train_files] + denoise_logger.info(">>> computing diameters") + diam_train = np.array([ + utils.diameters(io.imread(train_labels_files[k])[0])[0] + for k in trange(len(train_labels_files)) + ]) + diam_train[diam_train < 5] = 5. + if test_files is not None: + nimg_test = len(test_files) + test_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in test_files] + diam_test = np.array([ + utils.diameters(io.imread(test_labels_files[k])[0])[0] + for k in trange(len(test_labels_files)) + ]) + diam_test[diam_test < 5] = 5. + train_probs = 1. / nimg * np.ones(nimg, + "float64") if train_probs is None else train_probs + if test_files is not None or test_data is not None: + test_probs = 1. / nimg_test * np.ones( + nimg_test, "float64") if test_probs is None else test_probs + + tic = time.time() + + nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch + if test_files is not None or test_data is not None: + nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch + + nbatch = 0 + train_losses, test_losses = [], [] + for iepoch in range(n_epochs): + np.random.seed(iepoch) + rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,), + p=train_probs) + torch.manual_seed(iepoch) + np.random.seed(iepoch) + for param_group in optimizer.param_groups: + param_group["lr"] = learning_rate[iepoch] + lavg, lavg_per, nsum = 0, 0, 0 + for ibatch in range(0, nimg_per_epoch, batch_size * nnoise): + inds = rperm[ibatch : ibatch + batch_size * nnoise] + if train_data is None: + imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds] + lbls = [io.imread(train_labels_files[i])[1:] for i in inds] + else: + imgs = [train_data[i][:nchan] for i in inds] + lbls = [train_labels[i][1:] for i in inds] + #inoise = nbatch % nnoise + rnoise = 
np.random.permutation(nnoise) + for i, inoise in enumerate(rnoise): + if i * batch_size < len(imgs): + imgi, lbli, scale = random_rotate_and_resize_noise( + imgs[i * batch_size : (i + 1) * batch_size], + lbls[i * batch_size : (i + 1) * batch_size], + diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(), + poisson=poisson[inoise], + beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso, + downsample=downsample[inoise], uniform_blur=uniform_blur, + diam_mean=diam_mean, ds_max=ds_max, + device=device) + if i == 0: + img = imgi + lbl = lbli + else: + img = torch.cat((img, imgi), axis=0) + lbl = torch.cat((lbl, lbli), axis=0) + + if nnoise > 0: + iperm = np.random.permutation(img.shape[0]) + img, lbl = img[iperm], lbl[iperm] + + for i in range(nnoise): + optimizer.zero_grad() + imgi = img[i * batch_size: (i + 1) * batch_size] + lbli = lbl[i * batch_size: (i + 1) * batch_size] + if imgi.shape[0] > 0: + loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1, + img=imgi[:, nchan:], lbl=lbli, lam=lam) + loss.backward() + optimizer.step() + lavg += loss.item() * imgi.shape[0] + lavg_per += loss_per.item() * imgi.shape[0] + + nsum += len(img) + nbatch += 1 + + if iepoch % 5 == 0 or iepoch < 10: + lavg = lavg / nsum + lavg_per = lavg_per / nsum + if test_data is not None or test_files is not None: + lavgt, nsum = 0., 0 + np.random.seed(42) + rperm = np.random.choice(np.arange(0, nimg_test), + size=(nimg_test_per_epoch,), p=test_probs) + inoise = iepoch % nnoise + torch.manual_seed(inoise) + for ibatch in range(0, nimg_test_per_epoch, batch_size): + inds = rperm[ibatch:ibatch + batch_size] + if test_data is None: + imgs = [ + np.maximum(0, + io.imread(test_files[i])[:nchan]) for i in inds + ] + lbls = [io.imread(test_labels_files[i])[1:] for i in inds] + else: + imgs = [test_data[i][:nchan] for i in inds] + lbls = [test_labels[i][1:] for i in inds] + img, lbl, scale = random_rotate_and_resize_noise( + imgs, lbls, diam_test[inds].copy(), 
poisson=poisson[inoise], + beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise], + iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur, + diam_mean=diam_mean, ds_max=ds_max, device=device) + loss, loss_per = test_loss(net, img[:, :nchan], net1=net1, + img=img[:, nchan:], lbl=lbl, lam=lam) + + lavgt += loss.item() * img.shape[0] + nsum += len(img) + lavgt = lavgt / nsum + denoise_logger.info( + "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f" + % (iepoch, time.time() - tic, lavg, lavg_per, lavgt, + learning_rate[iepoch])) + test_losses.append(lavgt) + else: + denoise_logger.info( + "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" % + (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch])) + train_losses.append(lavg) + + if save_path is not None: + if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0): + if save_each: #separate files as model progresses + filename0 = str(filename) + f"_epoch_{iepoch:%04d}" + else: + filename0 = filename + denoise_logger.info(f"saving network parameters to {filename0}") + net.save_model(filename0) + else: + filename = save_path + + return filename, train_losses, test_losses + + +if __name__ == "__main__": + import argparse + parser = argparse.ArgumentParser(description="cellpose parameters") + + input_img_args = parser.add_argument_group("input image arguments") + input_img_args.add_argument("--dir", default=[], type=str, + help="folder containing data to run or train on.") + input_img_args.add_argument("--img_filter", default=[], type=str, + help="end string for images to run on") + + model_args = parser.add_argument_group("model arguments") + model_args.add_argument("--pretrained_model", default=[], type=str, + help="pretrained denoising model") + + training_args = parser.add_argument_group("training arguments") + training_args.add_argument("--test_dir", default=[], type=str, + help="folder containing test data (optional)") + 
training_args.add_argument("--file_list", default=[], type=str, + help="npy file containing list of train and test files") + training_args.add_argument("--seg_model_type", default="cyto2", type=str, + help="model to use for seg training loss") + training_args.add_argument( + "--noise_type", default=[], type=str, + help="noise type to use (if input, then other noise params are ignored)") + training_args.add_argument("--poisson", default=0.8, type=float, + help="fraction of images to add poisson noise to") + training_args.add_argument("--beta", default=0.7, type=float, + help="scale of poisson noise") + training_args.add_argument("--blur", default=0., type=float, + help="fraction of images to blur") + training_args.add_argument("--gblur", default=1.0, type=float, + help="scale of gaussian blurring stddev") + training_args.add_argument("--downsample", default=0., type=float, + help="fraction of images to downsample") + training_args.add_argument("--ds_max", default=7, type=int, + help="max downsampling factor") + training_args.add_argument("--lam_per", default=1.0, type=float, + help="weighting of perceptual loss") + training_args.add_argument("--lam_seg", default=1.5, type=float, + help="weighting of segmentation loss") + training_args.add_argument("--lam_rec", default=0., type=float, + help="weighting of reconstruction loss") + training_args.add_argument( + "--diam_mean", default=30., type=float, help= + "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0" + ) + training_args.add_argument("--learning_rate", default=0.001, type=float, + help="learning rate. Default: %(default)s") + training_args.add_argument("--n_epochs", default=2000, type=int, + help="number of epochs. 
Default: %(default)s") + training_args.add_argument( + "--save_each", default=False, action="store_true", + help="save each epoch as separate model") + training_args.add_argument( + "--nimg_per_epoch", default=0, type=int, + help="number of images per epoch. Default is length of training images") + training_args.add_argument( + "--nimg_test_per_epoch", default=0, type=int, + help="number of test images per epoch. Default is length of testing images") + + io.logger_setup() + + args = parser.parse_args() + lams = [args.lam_per, args.lam_seg, args.lam_rec] + print("lam", lams) + + if len(args.noise_type) > 0: + noise_type = args.noise_type + uniform_blur = False + iso = True + if noise_type == "poisson": + poisson = 0.8 + blur = 0. + downsample = 0. + beta = 0.7 + gblur = 1.0 + elif noise_type == "blur_expr": + poisson = 0.8 + blur = 0.8 + downsample = 0. + beta = 0.1 + gblur = 0.5 + elif noise_type == "blur": + poisson = 0.8 + blur = 0.8 + downsample = 0. + beta = 0.1 + gblur = 10.0 + uniform_blur = True + elif noise_type == "downsample_expr": + poisson = 0.8 + blur = 0.8 + downsample = 0.8 + beta = 0.03 + gblur = 1.0 + elif noise_type == "downsample": + poisson = 0.8 + blur = 0.8 + downsample = 0.8 + beta = 0.03 + gblur = 5.0 + uniform_blur = True + elif noise_type == "all": + poisson = [0.8, 0.8, 0.8] + blur = [0., 0.8, 0.8] + downsample = [0., 0., 0.8] + beta = [0.7, 0.1, 0.03] + gblur = [0., 10.0, 5.0] + uniform_blur = True + elif noise_type == "aniso": + poisson = 0.8 + blur = 0.8 + downsample = 0.8 + beta = 0.1 + gblur = args.ds_max * 1.5 + iso = False + else: + raise ValueError(f"{noise_type} noise_type is not supported") + else: + poisson, beta = args.poisson, args.beta + blur, gblur = args.blur, args.gblur + downsample = args.downsample + + pretrained_model = None if len( + args.pretrained_model) == 0 else args.pretrained_model + model = DenoiseModel(gpu=True, nchan=1, diam_mean=args.diam_mean, + pretrained_model=pretrained_model) + + train_data, labels, 
train_files, train_probs = None, None, None, None + test_data, test_labels, test_files, test_probs = None, None, None, None + if len(args.file_list) == 0: + output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0) + images, labels, image_names, test_images, test_labels, image_names_test = output + train_data = [] + for i in range(len(images)): + img = images[i].astype("float32") + if img.ndim > 2: + img = img[0] + train_data.append( + np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :]) + if len(args.test_dir) > 0: + test_data = [] + for i in range(len(test_images)): + img = test_images[i].astype("float32") + if img.ndim > 2: + img = img[0] + test_data.append( + np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :]) + save_path = os.path.join(args.dir, "../models/") + else: + root = args.dir + denoise_logger.info( + ">>> using file_list (assumes images are normalized and have flows!)") + dat = np.load(args.file_list, allow_pickle=True).item() + train_files = dat["train_files"] + test_files = dat["test_files"] + train_probs = dat["train_probs"] if "train_probs" in dat else None + test_probs = dat["test_probs"] if "test_probs" in dat else None + if str(train_files[0])[:len(str(root))] != str(root): + for i in range(len(train_files)): + new_path = root / Path(*train_files[i].parts[-3:]) + if i == 0: + print(f"changing path from {train_files[i]} to {new_path}") + train_files[i] = new_path + + for i in range(len(test_files)): + new_path = root / Path(*test_files[i].parts[-3:]) + test_files[i] = new_path + save_path = os.path.join(args.dir, "models/") + + os.makedirs(save_path, exist_ok=True) + + nimg_per_epoch = None if args.nimg_per_epoch == 0 else args.nimg_per_epoch + nimg_test_per_epoch = None if args.nimg_test_per_epoch == 0 else args.nimg_test_per_epoch + + model_path = train( + model.net, train_data=train_data, train_labels=labels, train_files=train_files, + test_data=test_data, test_labels=test_labels, test_files=test_files, 
def seg_train_noisy(model, train_data, train_labels, test_data=None, test_labels=None,
                    poisson=0.8, blur=0.0, downsample=0.0, save_path=None,
                    save_every=100, save_each=False, learning_rate=0.2, n_epochs=500,
                    momentum=0.9, weight_decay=0.00001, SGD=True, batch_size=8,
                    nimg_per_epoch=None, diameter=None, rescale=True, z_masking=False,
                    model_name=None):
    """Train a segmentation model on synthetically noised images.

    Uses the loss function ``model.loss_fn`` (see models.py); data should
    already be normalized. Noise (poisson / blur / downsample) is applied on
    the fly by ``random_rotate_and_resize_noise`` and only the noisy channel
    (``imgi[:, :1]``) is fed to the network.

    Args:
        model: segmentation model object exposing ``_set_optimizer``,
            ``_set_criterion``, ``_train_step``, ``_test_eval``, ``net``, etc.
        train_data (list): training images (channels first, normalized).
        train_labels (list): training labels; label[0] is masks, label[1:] are flows.
        test_data, test_labels (list, optional): held-out data evaluated every 10 epochs.
        poisson, blur, downsample (float): fraction of images receiving each degradation.
        save_path (str, optional): root folder; models are written to ``save_path/models/``.
        save_every (int): epoch period for checkpointing.
        save_each (bool): save each checkpoint to a separate file.
        learning_rate (float or array): constant LR or per-epoch schedule of length n_epochs.
        n_epochs, momentum, weight_decay, SGD, batch_size: optimizer settings.
        nimg_per_epoch (int, optional): images sampled per epoch (default: all).
        diameter (float, optional): fixed object diameter; if None it is estimated from labels.
        rescale (bool): rescale images to ``model.diam_mean`` during augmentation.
        z_masking (bool): randomly zero leading/trailing channels (multi-plane input).
        model_name (str, optional): filename stem for saved models.

    Returns:
        str: path of the last saved model file (or ``save_path`` if never saved).
    """

    d = datetime.datetime.now()

    model.n_epochs = n_epochs
    if isinstance(learning_rate, (list, np.ndarray)):
        if isinstance(learning_rate, np.ndarray) and learning_rate.ndim > 1:
            raise ValueError("learning_rate.ndim must equal 1")
        elif len(learning_rate) != n_epochs:
            raise ValueError(
                "if learning_rate given as list or np.ndarray it must have length n_epochs"
            )
        model.learning_rate = learning_rate
        model.learning_rate_const = mode(learning_rate)[0][0]
    else:
        model.learning_rate_const = learning_rate
        # set learning rate schedule: 10-epoch warmup, then constant,
        # with halving steps at the end for long (>250 epoch) runs
        if SGD:
            LR = np.linspace(0, model.learning_rate_const, 10)
            if model.n_epochs > 250:
                LR = np.append(
                    LR, model.learning_rate_const * np.ones(model.n_epochs - 100))
                for i in range(10):
                    LR = np.append(LR, LR[-1] / 2 * np.ones(10))
            else:
                LR = np.append(
                    LR,
                    model.learning_rate_const * np.ones(max(0, model.n_epochs - 10)))
        else:
            LR = model.learning_rate_const * np.ones(model.n_epochs)
        model.learning_rate = LR

    model.batch_size = batch_size
    model._set_optimizer(model.learning_rate[0], momentum, weight_decay, SGD)
    model._set_criterion()

    nimg = len(train_data)

    # compute average cell diameter
    if diameter is None:
        diam_train = np.array(
            [utils.diameters(train_labels[k][0])[0] for k in range(len(train_labels))])
        diam_train_mean = diam_train[diam_train > 0].mean()
        model.diam_labels = diam_train_mean
        if rescale:
            diam_train[diam_train < 5] = 5.
            if test_data is not None:
                diam_test = np.array([
                    utils.diameters(test_labels[k][0])[0]
                    for k in range(len(test_labels))
                ])
                diam_test[diam_test < 5] = 5.
            denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
    else:
        # FIX: was `elif rescale:` — with a user-supplied diameter and
        # rescale=False, diam_train/diam_train_mean were never defined and the
        # first use below raised NameError. The diameters are needed regardless.
        diam_train_mean = diameter
        model.diam_labels = diameter
        denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
        diam_train = diameter * np.ones(len(train_labels), "float32")
        if test_data is not None:
            diam_test = diameter * np.ones(len(test_labels), "float32")

    denoise_logger.info(
        f">>>> mean of training label mask diameters (saved to model) {diam_train_mean:.3f}"
    )
    model.net.diam_labels.data = torch.ones(1, device=model.device) * diam_train_mean

    nchan = train_data[0].shape[0]
    denoise_logger.info(">>>> training network with %d channel input <<<<" % nchan)
    denoise_logger.info(">>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f" %
                        (model.learning_rate_const, model.batch_size, weight_decay))

    if test_data is not None:
        denoise_logger.info(f">>>> ntrain = {nimg}, ntest = {len(test_data)}")
    else:
        denoise_logger.info(f">>>> ntrain = {nimg}")

    tic = time.time()

    lavg, nsum = 0, 0

    if save_path is not None:
        _, file_label = os.path.split(save_path)
        file_path = os.path.join(save_path, "models/")

        if not os.path.exists(file_path):
            os.makedirs(file_path)
    else:
        denoise_logger.warning("WARNING: no save_path given, model not saving")

    ksave = 0

    # pre-draw shuffled indices for all epochs (reproducible via seed 0)
    np.random.seed(0)
    inds_all = np.zeros((0,), "int32")
    if nimg_per_epoch is None or nimg > nimg_per_epoch:
        nimg_per_epoch = nimg
    denoise_logger.info(f">>>> nimg_per_epoch = {nimg_per_epoch}")
    while len(inds_all) < n_epochs * nimg_per_epoch:
        rperm = np.random.permutation(nimg)
        inds_all = np.hstack((inds_all, rperm))

    for iepoch in range(model.n_epochs):
        if SGD:
            model._set_learning_rate(model.learning_rate[iepoch])
        np.random.seed(iepoch)
        rperm = inds_all[iepoch * nimg_per_epoch:(iepoch + 1) * nimg_per_epoch]
        for ibatch in range(0, nimg_per_epoch, batch_size):
            inds = rperm[ibatch:ibatch + batch_size]
            imgi, lbl, scale = random_rotate_and_resize_noise(
                [train_data[i] for i in inds], [train_labels[i][1:] for i in inds],
                poisson=poisson, blur=blur, downsample=downsample,
                diams=diam_train[inds], diam_mean=model.diam_mean)
            imgi = imgi[:, :1]  # keep noisy only
            if z_masking:
                # NOTE(review): with the single-channel input above nc==1, so
                # randint(nc // 2 - 1) would fail — z_masking presumably
                # assumes multi-plane input; confirm before enabling.
                nc = imgi.shape[1]
                nb = imgi.shape[0]
                ncmin = (np.random.rand(nb) > 0.25) * (np.random.randint(
                    nc // 2 - 1, size=nb))
                ncmax = nc - (np.random.rand(nb) > 0.25) * (np.random.randint(
                    nc // 2 - 1, size=nb))
                for b in range(nb):
                    imgi[b, :ncmin[b]] = 0
                    imgi[b, ncmax[b]:] = 0

            train_loss = model._train_step(imgi, lbl)
            lavg += train_loss
            nsum += len(imgi)

        if iepoch % 10 == 0 or iepoch == 5:
            lavg = lavg / nsum
            if test_data is not None:
                lavgt, nsum = 0., 0
                np.random.seed(42)
                rperm = np.arange(0, len(test_data), 1, int)
                for ibatch in range(0, len(test_data), batch_size):
                    inds = rperm[ibatch:ibatch + batch_size]
                    imgi, lbl, scale = random_rotate_and_resize_noise(
                        [test_data[i] for i in inds],
                        [test_labels[i][1:] for i in inds], poisson=poisson, blur=blur,
                        downsample=downsample, diams=diam_test[inds],
                        diam_mean=model.diam_mean)
                    imgi = imgi[:, :1]  # keep noisy only
                    test_loss = model._test_eval(imgi, lbl)
                    lavgt += test_loss
                    nsum += len(imgi)

                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %2.4f, Loss Test %2.4f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, lavgt / nsum,
                     model.learning_rate[iepoch]))
            else:
                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %2.4f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, model.learning_rate[iepoch]))

            lavg, nsum = 0, 0

        if save_path is not None:
            if iepoch == model.n_epochs - 1 or iepoch % save_every == 1:
                # save model at the end
                if save_each:  #separate files as model progresses
                    if model_name is None:
                        filename = "{}_{}_{}_{}".format(
                            model.net_type, file_label,
                            d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch))
                    else:
                        filename = "{}_{}".format(model_name, "epoch_" + str(iepoch))
                else:
                    if model_name is None:
                        filename = "{}_{}_{}".format(model.net_type, file_label,
                                                     d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
                    else:
                        filename = model_name
                filename = os.path.join(file_path, filename)
                ksave += 1
                # FIX: the log message was the literal "(unknown)" with no
                # placeholder; log the actual path like the sibling train().
                denoise_logger.info(f"saving network parameters to {filename}")
                model.net.save_model(filename)
        else:
            filename = save_path

    return filename
def center_of_mass(mask):
    """Return the (y, x) mask pixel closest to the mask's mean position.

    If the rounded mean position itself lies inside the mask it is returned
    directly; otherwise the nearest in-mask pixel (squared Euclidean
    distance) is used, so the result is always a pixel of the mask.
    """
    rows, cols = np.nonzero(mask)
    r0 = int(np.round(rows.sum() / len(rows)))
    c0 = int(np.round(cols.sum() / len(cols)))
    mean_inside = ((rows == r0) & (cols == c0)).any()
    if not mean_inside:
        # fall back to the in-mask pixel nearest (r0, c0)
        nearest = ((cols - c0) ** 2 + (rows - r0) ** 2).argmin()
        r0 = rows[nearest]
        c0 = cols[nearest]
    return r0, c0


def get_centers(masks, slices):
    """Center pixel and bounding-box extent for each labeled mask.

    Args:
        masks: labeled 2D array (0 = background, 1..n = mask labels).
        slices: per-label bounding-box slices, e.g. from
            ``scipy.ndimage.find_objects`` (slices[i] bounds label i+1).

    Returns:
        (centers, exts): centers is an (n, 2) int array of (y, x) center
        pixels in full-image coordinates; exts is an (n,) int array of
        height + width + 2 of each bounding box.
    """
    center_list = []
    ext_list = []
    for label_idx, (sy, sx) in enumerate(slices):
        cy, cx = center_of_mass(masks[sy, sx] == (label_idx + 1))
        center_list.append([cy + sy.start, cx + sx.start])
        ext_list.append((sy.stop - sy.start) + (sx.stop - sx.start) + 2)
    return np.array(center_list), np.array(ext_list)
def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
    """Convert 2D masks to flows by diffusion from each mask's center pixel.

    The diffusion seed of each mask is the pixel closest to the rounded mean
    position within the mask (see get_centers). The flow field is the
    normalized gradient of the diffused heat.

    Args:
        masks (int, 2D array): labeled masks, 0 = background, 1,2,... = labels.
        device (torch.device, optional): device for the computation. Defaults
            to CPU. NOTE(review): the `device is None` branch below only runs
            if a caller explicitly passes None, and can itself produce None
            when neither CUDA nor MPS is available — confirm intended fallback.
        niter (int, optional): diffusion iterations; if None, 2x the largest
            bounding-box extent is used.

    Returns:
        np.ndarray: float array [2 x Ly x Lx] of (dY, dX) unit flows; zeros
        where there are no masks.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

    if masks.max() > 0:
        Ly0, Lx0 = masks.shape
        # padded size (one-pixel border so every mask pixel has 8 neighbors)
        Ly, Lx = Ly0 + 2, Lx0 + 2

        masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
        masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
        shape = masks_padded.shape

        ### get mask pixel neighbors
        # neighbors[:, i] holds the (y, x) coords of the i-th of 9 offsets
        # (center, 4-connected, then diagonals) for every mask pixel
        y, x = torch.nonzero(masks_padded, as_tuple=True)
        y = y.int()
        x = x.int()
        neighbors = torch.zeros((2, 9, y.shape[0]), dtype=torch.int, device=device)
        yxi = [[0, -1, 1, 0, 0, -1, -1, 1, 1], [0, 0, 0, -1, 1, -1, 1, -1, 1]]
        for i in range(9):
            neighbors[0, i] = y + yxi[0][i]
            neighbors[1, i] = x + yxi[1][i]
        # a neighbor is valid only if it belongs to the same mask label
        isneighbor = torch.ones((9, y.shape[0]), dtype=torch.bool, device=device)
        m0 = masks_padded[neighbors[0, 0], neighbors[1, 0]]
        for i in range(1, 9):
            isneighbor[i] = masks_padded[neighbors[0, i], neighbors[1, i]] == m0
        del m0, masks_padded

        ### get center-of-mass within cell
        slices = find_objects(masks)
        centers, ext = get_centers(masks, slices)
        meds_p = torch.from_numpy(centers).to(device).long()
        meds_p += 1  # for padding

        ### run diffusion
        n_iter = 2 * ext.max() if niter is None else niter
        mu = _extend_centers_gpu(neighbors, meds_p, isneighbor, shape, n_iter=n_iter,
                                 device=device)
        mu = mu.astype("float64")

        # new normalization: unit-length flow vectors (1e-60 avoids div-by-0)
        mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)

        # put into original image (undo the one-pixel padding offset)
        mu0 = np.zeros((2, Ly0, Lx0))
        mu0[:, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
    else:
        # no masks, return empty flows
        mu0 = np.zeros((2, masks.shape[0], masks.shape[1]))
    return mu0
def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None,
                    return_flows=True):
    """Converts labels (list of masks or flows) to flows for training model.

    Args:
        labels (list of ND-arrays): The labels to convert. labels[k] can be 2D or 3D. If [3 x Ly x Lx],
            it is assumed that flows were precomputed. Otherwise, labels[k][0] or labels[k] (if 2D)
            is used to create flows and cell probabilities.
        files (list of str, optional): The files to save the flows to. If provided, flows are saved to
            files to be reused. Defaults to None.
        device (str, optional): The device to use for computation. Defaults to None.
        redo_flows (bool, optional): Whether to recompute the flows. Defaults to False.
        niter (int, optional): The number of iterations for computing flows. Defaults to None.
        return_flows (bool, optional): Whether to collect and return the flows. Defaults to True.

    Returns:
        list of [4 x Ly x Lx] float32 arrays: flows[k][0] is labels[k],
        flows[k][1] is the binary cell mask (labels > 0), flows[k][2] is the
        Y flow, and flows[k][3] is the X flow.
        (Docstring fix: the code concatenates labels + mask + 2 flow channels;
        there is no separate distance-transform or heat channel.)
    """
    nimg = len(labels)
    if labels[0].ndim < 3:
        labels = [labels[n][np.newaxis, :, :] for n in range(nimg)]

    flows = []
    # flows need to be recomputed
    if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows:
        dynamics_logger.info("computing flows for labels")

        # compute flows; labels are fixed here to be unique, so they need to be passed back
        # make sure labels are unique!
        labels = [fastremap.renumber(label, in_place=True)[0] for label in labels]
        iterator = trange if nimg > 1 else range
        for n in iterator(nimg):
            labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0]
            vecn = masks_to_flows_gpu(labels[n][0].astype(int), device=device, niter=niter)

            # concatenate labels, binary mask, vector flows (boundary and mask are computed in augmentations)
            flow = np.concatenate((labels[n], labels[n] > 0.5, vecn),
                                  axis=0).astype(np.float32)
            if files is not None:
                file_name = os.path.splitext(files[n])[0]
                tifffile.imwrite(file_name + "_flows.tif", flow)
            if return_flows:
                flows.append(flow)
    else:
        dynamics_logger.info("flows precomputed")
        if return_flows:
            flows = [labels[n].astype(np.float32) for n in range(nimg)]
    return flows
def steps_interp(dP, inds, niter, device=torch.device("cpu")):
    """Run pixel dynamics in 2D/3D with interpolation between pixel values.

    Euler-integrates the flow field dP for niter steps, starting from the
    pixel coordinates in inds, sampling the field with bilinear/trilinear
    interpolation via torch.nn.functional.grid_sample.

    Args:
        dP (numpy.ndarray): flow field of shape (2, Ly, Lx) or (3, Lz, Ly, Lx).
        inds (tuple of numpy.ndarray): per-axis integer coordinates of the
            starting pixels (as returned by np.nonzero).
        niter (int): number of integration steps.
        device (torch.device, optional): device for computation. Defaults to CPU.

    Returns:
        torch.Tensor: final pixel locations, shape (ndim, n_points), axis
        order matching dP (y,x or z,y,x).
    """

    shape = dP.shape[1:]
    ndim = len(shape)

    # pt: sampling grid for grid_sample, shape (1, [1,] 1, n_points, ndim);
    # im: flow image, shape (1, ndim, *shape). grid_sample expects the last
    # axis of pt in (x, y[, z]) order, hence the ndim - n - 1 flip below.
    pt = torch.zeros((*[1]*ndim, len(inds[0]), ndim), dtype=torch.float32, device=device)
    im = torch.zeros((1, ndim, *shape), dtype=torch.float32, device=device)
    # Y and X dimensions, flipped X-1, Y-1
    # pt is [1 1 1 3 n_points]
    for n in range(ndim):
        if ndim==3:
            pt[0, 0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
        else:
            pt[0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
        im[0, ndim - n - 1] = torch.from_numpy(dP[n]).to(device, dtype=torch.float32)
    # reversed axis lengths minus 1, matching pt's (x, y[, z]) ordering
    shape = np.array(shape)[::-1].astype("float") - 1

    # normalize pt between 0 and 1, normalize the flow
    for k in range(ndim):
        im[:, k] *= 2. / shape[k]
        pt[..., k] /= shape[k]

    # normalize to between -1 and 1 (grid_sample's coordinate convention)
    pt *= 2
    pt -= 1

    # dynamics: each step samples the (rescaled) flow at the current
    # positions and moves the points, clamped to the image bounds
    for t in range(niter):
        dPt = torch.nn.functional.grid_sample(im, pt, align_corners=False)
        for k in range(ndim): #clamp the final pixel locations
            pt[..., k] = torch.clamp(pt[..., k] + dPt[:, k], -1., 1.)

    #undo the normalization from before, reverse order of operations
    pt += 1
    pt *= 0.5
    for k in range(ndim):
        pt[..., k] *= shape[k]

    # flip back to (y, x) / (z, y, x) axis order and return (ndim, n_points)
    if ndim==3:
        pt = pt[..., [2, 1, 0]].squeeze()
        pt = pt.unsqueeze(0) if pt.ndim==1 else pt
        return pt.T
    else:
        pt = pt[..., [1, 0]].squeeze()
        pt = pt.unsqueeze(0) if pt.ndim==1 else pt
        return pt.T
def follow_flows(dP, inds, niter=200, device=torch.device("cpu")):
    """Run dynamics to recover masks in 2D or 3D.

    Only pixels with non-zero cell probability (as selected by inds) are
    integrated; the heavy lifting is done by steps_interp.

    Args:
        dP (np.ndarray): flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
        inds (tuple of np.ndarray): per-axis coordinates of the pixels to
            integrate (e.g. from np.nonzero on the cell-probability mask).
        niter (int, optional): number of dynamics iterations. Default 200.
        device (torch.device, optional): computation device. Default CPU.

    Returns:
        torch.Tensor: final locations of each pixel, shape (ndim, n_points).
        (Docstring fix: previous version documented nonexistent `mask`/`interp`
        parameters and a (p, inds) tuple return; only p is returned.)
    """
    # FIX: removed unused locals (shape/ndim were computed and never read)
    return steps_interp(dP, inds, niter, device=device)
def max_pool1d(h, kernel_size=5, axis=1, out=None):
    """Sliding-window maximum along one axis (stride 1, padding kernel_size//2).

    Memory-efficient max_pool thanks to Mark Kittisopikul: instead of
    materializing all windows, the running maximum is accumulated over
    shifted views. Requires an odd kernel_size >= 3; axis must be 1, 2 or 3.
    The result is written into `out` (allocated as a copy of h if None).
    """
    if out is None:
        out = h.clone()
    else:
        out.copy_(h)

    n = h.shape[axis]
    half = kernel_size // 2
    for shift in range(-half, half + 1):
        lo_out, hi_out = max(-shift, 0), min(n - shift, n)
        lo_in, hi_in = max(shift, 0), min(n + shift, n)
        if axis == 1:
            dst, src = out[:, lo_out:hi_out], h[:, lo_in:hi_in]
        elif axis == 2:
            dst, src = out[:, :, lo_out:hi_out], h[:, :, lo_in:hi_in]
        elif axis == 3:
            dst, src = out[:, :, :, lo_out:hi_out], h[:, :, :, lo_in:hi_in]
        # accumulate the elementwise maximum in place
        torch.maximum(dst, src, out=dst)
    return out


def max_pool_nd(h, kernel_size=5):
    """Memory-efficient separable max pool in 2D or 3D (stride 1, same size).

    Applies max_pool1d along each spatial axis in turn; for a square/cubic
    window this equals a full 2D/3D max pool. h has a leading batch axis.
    """
    ndim = h.ndim - 1
    pooled = max_pool1d(h, kernel_size=kernel_size, axis=1)
    pooled2 = max_pool1d(pooled, kernel_size=kernel_size, axis=2)
    if ndim == 2:
        del pooled
        return pooled2
    # 3D: pool the third axis, reusing the first buffer as output storage
    pooled = max_pool1d(pooled2, kernel_size=kernel_size, axis=3, out=pooled)
    del pooled2
    return pooled
def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
    """Create masks from pixel convergence after running dynamics.

    Builds a histogram of the final pixel locations pt, finds local maxima
    with enough converged pixels (seeds), grows each seed within an 11^ndim
    neighborhood, and labels every original pixel by the seed its trajectory
    converged to.

    Args:
        pt (torch.Tensor): final locations of each pixel after dynamics,
            shape (ndim, n_points). Modified in place (shifted by rpad).
        inds (tuple of arrays): original per-axis coordinates of those pixels
            in the full image.
        shape0 (tuple): shape of the output mask image.
        rpad (int, optional): histogram edge padding. NOTE(review): this
            parameter is immediately overwritten by the hard-coded
            `rpad = 20` below, so passing a different value has no effect.
        max_size_fraction (float, optional): masks larger than this fraction
            of the total image size are removed. Default 0.4.

    Returns:
        np.ndarray: labeled masks (0 = no mask; 1, 2, ... = labels) of shape
        shape0, dtype uint16 (or uint32 for >= 2**16 seeds).
    """

    ndim = len(shape0)
    device = pt.device

    rpad = 20
    # shift into padded coordinates and clamp to the padded histogram bounds
    pt += rpad
    pt = torch.clamp(pt, min=0)
    for i in range(len(pt)):
        pt[i] = torch.clamp(pt[i], max=shape0[i]+rpad-1)

    # # add extra padding to make divisible by 5
    # shape = tuple((np.ceil((shape0 + 2*rpad)/5) * 5).astype(int))
    shape = tuple(np.array(shape0) + 2*rpad)

    # sparse coo torch: histogram of final pixel locations
    coo = torch.sparse_coo_tensor(pt, torch.ones(pt.shape[1], device=pt.device, dtype=torch.int),
                                  shape)
    h1 = coo.to_dense()
    del coo

    # seeds = local maxima (within a 5-wide window) with > 10 converged pixels
    hmax1 = max_pool_nd(h1.unsqueeze(0), kernel_size=5)
    hmax1 = hmax1.squeeze()
    seeds1 = torch.nonzero((h1 - hmax1 > -1e-6) * (h1 > 10))
    del hmax1
    if len(seeds1) == 0:
        dynamics_logger.warning("no seeds found in get_masks_torch - no masks found.")
        return np.zeros(shape0, dtype="uint16")

    # sort seeds by histogram count (ascending)
    npts = h1[tuple(seeds1.T)]
    isort1 = npts.argsort()
    seeds1 = seeds1[isort1]

    # cut out an 11^ndim histogram patch around each seed
    n_seeds = len(seeds1)
    h_slc = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
    for k in range(n_seeds):
        slc = tuple([slice(seeds1[k][j]-5, seeds1[k][j]+6) for j in range(ndim)])
        h_slc[k] = h1[slc]
    del h1
    # start each seed mask as the single center pixel of its patch
    seed_masks = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
    if ndim==2:
        seed_masks[:,5,5] = 1
    else:
        seed_masks[:,5,5,5] = 1

    for iter in range(5):
        # extend: dilate by one pixel, keep only bins with > 2 converged pixels
        seed_masks = max_pool_nd(seed_masks, kernel_size=3)
        seed_masks *= h_slc > 2
    del h_slc
    # translate patch-local seed extents back to padded-histogram coordinates
    seeds_new = [tuple((torch.nonzero(seed_masks[k]) + seeds1[k] - 5).T)
                 for k in range(n_seeds)]
    del seed_masks

    dtype = torch.int32 if n_seeds < 2**16 else torch.int64
    M1 = torch.zeros(shape, dtype=dtype, device=device)
    for k in range(n_seeds):
        M1[seeds_new[k]] = 1 + k

    # label each trajectory endpoint, then scatter labels back to the
    # original pixel positions
    M1 = M1[tuple(pt)]
    M1 = M1.cpu().numpy()

    dtype = "uint16" if n_seeds < 2**16 else "uint32"
    M0 = np.zeros(shape0, dtype=dtype)
    M0[inds] = M1

    # remove big masks
    uniq, counts = fastremap.unique(M0, return_counts=True)
    big = np.prod(shape0) * max_size_fraction
    bigc = uniq[counts > big]
    if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
        M0 = fastremap.mask(M0, bigc)
    fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
    M0 = M0.reshape(tuple(shape0))

    #print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9:.3f} gb")
    return M0
+ """ + mask = compute_masks(dP, cellprob, niter=niter, + cellprob_threshold=cellprob_threshold, + flow_threshold=flow_threshold, do_3D=do_3D, + max_size_fraction=max_size_fraction, + device=device) + + if resize is not None: + dynamics_logger.warning("Resizing is depricated in v4.0.1+") + + mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size) + + return mask + + +def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0, + flow_threshold=0.4, do_3D=False, min_size=-1, + max_size_fraction=0.4, device=torch.device("cpu")): + """Compute masks using dynamics from dP and cellprob. + + Args: + dP (numpy.ndarray): The dynamics flow field array. + cellprob (numpy.ndarray): The cell probability array. + p (numpy.ndarray, optional): The pixels on which to run dynamics. Defaults to None + niter (int, optional): The number of iterations for mask computation. Defaults to 200. + cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0. + flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4. + interp (bool, optional): Whether to interpolate during dynamics computation. Defaults to True. + do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False. + min_size (int, optional): The minimum size of the masks. Defaults to 15. + max_size_fraction (float, optional): Masks larger than max_size_fraction of + total image size are removed. Default is 0.4. + device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu"). + + Returns: + tuple: A tuple containing the computed masks and the final pixel locations. 
+ """ + + if (cellprob > cellprob_threshold).sum(): #mask at this point is a cell cluster binary map, not labels + inds = np.nonzero(cellprob > cellprob_threshold) + if len(inds[0]) == 0: + dynamics_logger.info("No cell pixels found.") + shape = cellprob.shape + mask = np.zeros(shape, "uint16") + return mask + + p_final = follow_flows(dP * (cellprob > cellprob_threshold) / 5., + inds=inds, niter=niter, + device=device) + if not torch.is_tensor(p_final): + p_final = torch.from_numpy(p_final).to(device, dtype=torch.int) + else: + p_final = p_final.int() + # calculate masks + if device.type == "mps": + p_final = p_final.to(torch.device("cpu")) + mask = get_masks_torch(p_final, inds, dP.shape[1:], + max_size_fraction=max_size_fraction) + del p_final + # flow thresholding factored out of get_masks + if not do_3D: + if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0: + # make sure labels are unique at output of get_masks + mask = remove_bad_flow_masks(mask, dP, threshold=flow_threshold, + device=device) + + if mask.max() < 2**16 and mask.dtype != "uint16": + mask = mask.astype("uint16") + + else: # nothing to compute, just make it compatible + dynamics_logger.info("No cell pixels found.") + shape = cellprob.shape + mask = np.zeros(cellprob.shape, "uint16") + return mask + + if min_size > 0: + mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size) + + if mask.dtype == np.uint32: + dynamics_logger.warning( + "more than 65535 masks in image, masks returned as np.uint32") + + return mask diff --git a/models/seg_post_model/cellpose/export.py b/models/seg_post_model/cellpose/export.py new file mode 100644 index 0000000000000000000000000000000000000000..9052a3b44b6e122085792b922c3bcceaaa9d8166 --- /dev/null +++ b/models/seg_post_model/cellpose/export.py @@ -0,0 +1,405 @@ +"""Auxiliary module for bioimageio format export + +Example usage: + +```bash +#!/bin/bash + +# Define default paths and parameters +DEFAULT_CHANNELS="1 0" 
+DEFAULT_PATH_PRETRAINED_MODEL="/home/qinyu/models/cp/cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995" +DEFAULT_PATH_README="/home/qinyu/models/cp/README.md" +DEFAULT_LIST_PATH_COVER_IMAGES="/home/qinyu/images/cp/cellpose_raw_and_segmentation.jpg /home/qinyu/images/cp/cellpose_raw_and_probability.jpg /home/qinyu/images/cp/cellpose_raw.jpg" +DEFAULT_MODEL_ID="philosophical-panda" +DEFAULT_MODEL_ICON="🐼" +DEFAULT_MODEL_VERSION="0.1.0" +DEFAULT_MODEL_NAME="My Cool Cellpose" +DEFAULT_MODEL_DOCUMENTATION="A cool Cellpose model trained for my cool dataset." +DEFAULT_MODEL_AUTHORS='[{"name": "Qin Yu", "affiliation": "EMBL", "github_user": "qin-yu", "orcid": "0000-0002-4652-0795"}]' +DEFAULT_MODEL_CITE='[{"text": "For more details of the model itself, see the manuscript", "doi": "10.1242/dev.202800", "url": null}]' +DEFAULT_MODEL_TAGS="cellpose 3d 2d" +DEFAULT_MODEL_LICENSE="MIT" +DEFAULT_MODEL_REPO="https://github.com/kreshuklab/go-nuclear" + +# Run the Python script with default parameters +python export.py \ + --channels $DEFAULT_CHANNELS \ + --path_pretrained_model "$DEFAULT_PATH_PRETRAINED_MODEL" \ + --path_readme "$DEFAULT_PATH_README" \ + --list_path_cover_images $DEFAULT_LIST_PATH_COVER_IMAGES \ + --model_version "$DEFAULT_MODEL_VERSION" \ + --model_name "$DEFAULT_MODEL_NAME" \ + --model_documentation "$DEFAULT_MODEL_DOCUMENTATION" \ + --model_authors "$DEFAULT_MODEL_AUTHORS" \ + --model_cite "$DEFAULT_MODEL_CITE" \ + --model_tags $DEFAULT_MODEL_TAGS \ + --model_license "$DEFAULT_MODEL_LICENSE" \ + --model_repo "$DEFAULT_MODEL_REPO" +``` +""" + +import os +import sys +import json +import argparse +from pathlib import Path +from urllib.parse import urlparse + +import torch +import numpy as np + +from cellpose.io import imread +from cellpose.utils import download_url_to_file +from cellpose.transforms import pad_image_ND, normalize_img, convert_image +from cellpose.vit_sam import CPnetBioImageIO + +from bioimageio.spec.model.v0_5 
import ( + ArchitectureFromFileDescr, + Author, + AxisId, + ChannelAxis, + CiteEntry, + Doi, + FileDescr, + Identifier, + InputTensorDescr, + IntervalOrRatioDataDescr, + LicenseId, + ModelDescr, + ModelId, + OrcidId, + OutputTensorDescr, + ParameterizedSize, + PytorchStateDictWeightsDescr, + SizeReference, + SpaceInputAxis, + SpaceOutputAxis, + TensorId, + TorchscriptWeightsDescr, + Version, + WeightsDescr, +) +# Define ARBITRARY_SIZE if it is not available in the module +try: + from bioimageio.spec.model.v0_5 import ARBITRARY_SIZE +except ImportError: + ARBITRARY_SIZE = ParameterizedSize(min=1, step=1) + +from bioimageio.spec.common import HttpUrl +from bioimageio.spec import save_bioimageio_package +from bioimageio.core import test_model + +DEFAULT_CHANNELS = [2, 1] +DEFAULT_NORMALIZE_PARAMS = { + "axis": -1, + "lowhigh": None, + "percentile": None, + "normalize": True, + "norm3D": False, + "sharpen_radius": 0, + "smooth_radius": 0, + "tile_norm_blocksize": 0, + "tile_norm_smooth3D": 1, + "invert": False, +} +IMAGE_URL = "http://www.cellpose.org/static/data/rgb_3D.tif" + + +def download_and_normalize_image(path_dir_temp, channels=DEFAULT_CHANNELS): + """ + Download and normalize image. 
+ """ + filename = os.path.basename(urlparse(IMAGE_URL).path) + path_image = path_dir_temp / filename + if not path_image.exists(): + sys.stderr.write(f'Downloading: "{IMAGE_URL}" to {path_image}\n') + download_url_to_file(IMAGE_URL, path_image) + img = imread(path_image).astype(np.float32) + img = convert_image(img, channels, channel_axis=1, z_axis=0, do_3D=False, nchan=2) + img = normalize_img(img, **DEFAULT_NORMALIZE_PARAMS) + img = np.transpose(img, (0, 3, 1, 2)) + img, _, _ = pad_image_ND(img) + return img + + +def load_bioimageio_cpnet_model(path_model_weight, nchan=2): + cpnet_kwargs = { + "nout": 3, + } + cpnet_biio = CPnetBioImageIO(**cpnet_kwargs) + state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True) + cpnet_biio.load_state_dict(state_dict_cuda) + cpnet_biio.eval() # crucial for the prediction results + return cpnet_biio, cpnet_kwargs + + +def descr_gen_input(path_test_input, nchan=2): + input_axes = [ + SpaceInputAxis(id=AxisId("z"), size=ARBITRARY_SIZE), + ChannelAxis(channel_names=[Identifier(f"c{i+1}") for i in range(nchan)]), + SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=16, step=16)), + SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=16, step=16)), + ] + data_descr = IntervalOrRatioDataDescr(type="float32") + path_test_input = Path(path_test_input) + descr_input = InputTensorDescr( + id=TensorId("raw"), + axes=input_axes, + test_tensor=FileDescr(source=path_test_input), + data=data_descr, + ) + return descr_input + + +def descr_gen_output_flow(path_test_output): + output_axes_output_tensor = [ + SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))), + ChannelAxis(channel_names=[Identifier("flow1"), Identifier("flow2"), Identifier("flow3")]), + SpaceOutputAxis(id=AxisId("y"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y"))), + SpaceOutputAxis(id=AxisId("x"), size=SizeReference(tensor_id=TensorId("raw"), 
axis_id=AxisId("x"))), + ] + path_test_output = Path(path_test_output) + descr_output = OutputTensorDescr( + id=TensorId("flow"), + axes=output_axes_output_tensor, + test_tensor=FileDescr(source=path_test_output), + ) + return descr_output + + +def descr_gen_output_downsampled(path_dir_temp, nbase=None): + if nbase is None: + nbase = [32, 64, 128, 256] + + output_axes_downsampled_tensors = [ + [ + SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))), + ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(base)]), + SpaceOutputAxis( + id=AxisId("y"), + size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")), + scale=2**offset, + ), + SpaceOutputAxis( + id=AxisId("x"), + size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")), + scale=2**offset, + ), + ] + for offset, base in enumerate(nbase) + ] + path_downsampled_tensors = [ + Path(path_dir_temp / f"test_downsampled_{i}.npy") for i in range(len(output_axes_downsampled_tensors)) + ] + descr_output_downsampled_tensors = [ + OutputTensorDescr( + id=TensorId(f"downsampled_{i}"), + axes=axes, + test_tensor=FileDescr(source=path), + ) + for i, (axes, path) in enumerate(zip(output_axes_downsampled_tensors, path_downsampled_tensors)) + ] + return descr_output_downsampled_tensors + + +def descr_gen_output_style(path_test_style, nchannel=256): + output_axes_style_tensor = [ + SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))), + ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(nchannel)]), + ] + path_style_tensor = Path(path_test_style) + descr_output_style_tensor = OutputTensorDescr( + id=TensorId("style"), + axes=output_axes_style_tensor, + test_tensor=FileDescr(source=path_style_tensor), + ) + return descr_output_style_tensor + + +def descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper=None): + if path_cpnet_wrapper is None: + path_cpnet_wrapper = Path(__file__).parent / 
"resnet_torch.py" + pytorch_architecture = ArchitectureFromFileDescr( + callable=Identifier("CPnetBioImageIO"), + source=Path(path_cpnet_wrapper), + kwargs=cpnet_kwargs, + ) + return pytorch_architecture + + +def descr_gen_documentation(path_doc, markdown_text): + with open(path_doc, "w") as f: + f.write(markdown_text) + + +def package_to_bioimageio( + path_pretrained_model, + path_save_trace, + path_readme, + list_path_cover_images, + descr_input, + descr_output, + descr_output_downsampled_tensors, + descr_output_style_tensor, + pytorch_version, + pytorch_architecture, + model_id, + model_icon, + model_version, + model_name, + model_documentation, + model_authors, + model_cite, + model_tags, + model_license, + model_repo, +): + """Package model description to BioImage.IO format.""" + my_model_descr = ModelDescr( + id=ModelId(model_id) if model_id is not None else None, + id_emoji=model_icon, + version=Version(model_version), + name=model_name, + description=model_documentation, + authors=[ + Author( + name=author["name"], + affiliation=author["affiliation"], + github_user=author["github_user"], + orcid=OrcidId(author["orcid"]), + ) + for author in model_authors + ], + cite=[CiteEntry(text=cite["text"], doi=Doi(cite["doi"]), url=cite["url"]) for cite in model_cite], + covers=[Path(img) for img in list_path_cover_images], + license=LicenseId(model_license), + tags=model_tags, + documentation=Path(path_readme), + git_repo=HttpUrl(model_repo), + inputs=[descr_input], + outputs=[descr_output, descr_output_style_tensor] + descr_output_downsampled_tensors, + weights=WeightsDescr( + pytorch_state_dict=PytorchStateDictWeightsDescr( + source=Path(path_pretrained_model), + architecture=pytorch_architecture, + pytorch_version=pytorch_version, + ), + torchscript=TorchscriptWeightsDescr( + source=Path(path_save_trace), + pytorch_version=pytorch_version, + parent="pytorch_state_dict", # these weights were converted from the pytorch_state_dict weights. 
+ ), + ), + ) + + return my_model_descr + + +def parse_args(): + # fmt: off + parser = argparse.ArgumentParser(description="BioImage.IO model packaging for Cellpose") + parser.add_argument("--channels", nargs=2, default=[2, 1], type=int, help="Cyto-only = [2, 0], Cyto + Nuclei = [2, 1], Nuclei-only = [1, 0]") + parser.add_argument("--path_pretrained_model", required=True, type=str, help="Path to pretrained model file, e.g., cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995") + parser.add_argument("--path_readme", required=True, type=str, help="Path to README file") + parser.add_argument("--list_path_cover_images", nargs='+', required=True, type=str, help="List of paths to cover images") + parser.add_argument("--model_id", type=str, help="Model ID, provide if already exists", default=None) + parser.add_argument("--model_icon", type=str, help="Model icon, provide if already exists", default=None) + parser.add_argument("--model_version", required=True, type=str, help="Model version, new model should be 0.1.0") + parser.add_argument("--model_name", required=True, type=str, help="Model name, e.g., My Cool Cellpose") + parser.add_argument("--model_documentation", required=True, type=str, help="Model documentation, e.g., A cool Cellpose model trained for my cool dataset.") + parser.add_argument("--model_authors", required=True, type=str, help="Model authors in JSON format, e.g., '[{\"name\": \"Qin Yu\", \"affiliation\": \"EMBL\", \"github_user\": \"qin-yu\", \"orcid\": \"0000-0002-4652-0795\"}]'") + parser.add_argument("--model_cite", required=True, type=str, help="Model citation in JSON format, e.g., '[{\"text\": \"For more details of the model itself, see the manuscript\", \"doi\": \"10.1242/dev.202800\", \"url\": null}]'") + parser.add_argument("--model_tags", nargs='+', required=True, type=str, help="Model tags, e.g., cellpose 3d 2d") + parser.add_argument("--model_license", required=True, type=str, help="Model license, e.g., MIT") + 
parser.add_argument("--model_repo", required=True, type=str, help="Model repository URL") + return parser.parse_args() + # fmt: on + + +def main(): + args = parse_args() + + # Parse user-provided paths and arguments + channels = args.channels + model_cite = json.loads(args.model_cite) + model_authors = json.loads(args.model_authors) + + path_readme = Path(args.path_readme) + path_pretrained_model = Path(args.path_pretrained_model) + list_path_cover_images = [Path(path_image) for path_image in args.list_path_cover_images] + + # Auto-generated paths + path_cpnet_wrapper = Path(__file__).resolve().parent / "resnet_torch.py" + path_dir_temp = Path(__file__).resolve().parent.parent / "models" / path_pretrained_model.stem + path_dir_temp.mkdir(parents=True, exist_ok=True) + + path_save_trace = path_dir_temp / "cp_traced.pt" + path_test_input = path_dir_temp / "test_input.npy" + path_test_output = path_dir_temp / "test_output.npy" + path_test_style = path_dir_temp / "test_style.npy" + path_bioimageio_package = path_dir_temp / "cellpose_model.zip" + + # Download test input image + img_np = download_and_normalize_image(path_dir_temp, channels=channels) + np.save(path_test_input, img_np) + img = torch.tensor(img_np).float() + + # Load model + cpnet_biio, cpnet_kwargs = load_bioimageio_cpnet_model(path_pretrained_model) + + # Test model and save output + tuple_output_tensor = cpnet_biio(img) + np.save(path_test_output, tuple_output_tensor[0].detach().numpy()) + np.save(path_test_style, tuple_output_tensor[1].detach().numpy()) + for i, t in enumerate(tuple_output_tensor[2:]): + np.save(path_dir_temp / f"test_downsampled_{i}.npy", t.detach().numpy()) + + # Save traced model + model_traced = torch.jit.trace(cpnet_biio, img) + model_traced.save(path_save_trace) + + # Generate model description + descr_input = descr_gen_input(path_test_input) + descr_output = descr_gen_output_flow(path_test_output) + descr_output_downsampled_tensors = descr_gen_output_downsampled(path_dir_temp, 
nbase=cpnet_biio.nbase[1:]) + descr_output_style_tensor = descr_gen_output_style(path_test_style, cpnet_biio.nbase[-1]) + pytorch_version = Version(torch.__version__) + pytorch_architecture = descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper) + + # Package model + my_model_descr = package_to_bioimageio( + path_pretrained_model, + path_save_trace, + path_readme, + list_path_cover_images, + descr_input, + descr_output, + descr_output_downsampled_tensors, + descr_output_style_tensor, + pytorch_version, + pytorch_architecture, + args.model_id, + args.model_icon, + args.model_version, + args.model_name, + args.model_documentation, + model_authors, + model_cite, + args.model_tags, + args.model_license, + args.model_repo, + ) + + # Test model + summary = test_model(my_model_descr, weight_format="pytorch_state_dict") + summary.display() + summary = test_model(my_model_descr, weight_format="torchscript") + summary.display() + + # Save BioImage.IO package + package_path = save_bioimageio_package(my_model_descr, output_path=Path(path_bioimageio_package)) + print("package path:", package_path) + + +if __name__ == "__main__": + main() diff --git a/models/seg_post_model/cellpose/gui/gui.py b/models/seg_post_model/cellpose/gui/gui.py new file mode 100644 index 0000000000000000000000000000000000000000..99c7fcd7c61851c9e6e68c66c4b4f1ef831aa55f --- /dev/null +++ b/models/seg_post_model/cellpose/gui/gui.py @@ -0,0 +1,2007 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu. 
+""" + +import sys, os, pathlib, warnings, datetime, time, copy + +from qtpy import QtGui, QtCore +from superqt import QRangeSlider, QCollapsible +from qtpy.QtWidgets import QScrollArea, QMainWindow, QApplication, QWidget, QScrollBar, \ + QComboBox, QGridLayout, QPushButton, QFrame, QCheckBox, QLabel, QProgressBar, \ + QLineEdit, QMessageBox, QGroupBox, QMenu, QAction +import pyqtgraph as pg + +import numpy as np +from scipy.stats import mode +import cv2 + +from . import guiparts, menus, io +from .. import models, core, dynamics, version, train +from ..utils import download_url_to_file, masks_to_outlines, diameters +from ..io import get_image_files, imsave, imread +from ..transforms import resize_image, normalize99, normalize99_tile, smooth_sharpen_img +from ..models import normalize_default +from ..plot import disk + +try: + import matplotlib.pyplot as plt + MATPLOTLIB = True +except: + MATPLOTLIB = False + +Horizontal = QtCore.Qt.Orientation.Horizontal + + +class Slider(QRangeSlider): + + def __init__(self, parent, name, color): + super().__init__(Horizontal) + self.setEnabled(False) + self.valueChanged.connect(lambda: self.levelChanged(parent)) + self.name = name + + self.setStyleSheet(""" QSlider{ + background-color: transparent; + } + """) + self.show() + + def levelChanged(self, parent): + parent.level_change(self.name) + + +class QHLine(QFrame): + + def __init__(self): + super(QHLine, self).__init__() + self.setFrameShape(QFrame.HLine) + self.setLineWidth(8) + + +def make_bwr(): + # make a bwr colormap + b = np.append(255 * np.ones(128), np.linspace(0, 255, 128)[::-1])[:, np.newaxis] + r = np.append(np.linspace(0, 255, 128), 255 * np.ones(128))[:, np.newaxis] + g = np.append(np.linspace(0, 255, 128), + np.linspace(0, 255, 128)[::-1])[:, np.newaxis] + color = np.concatenate((r, g, b), axis=-1).astype(np.uint8) + bwr = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) + return bwr + + +def make_spectral(): + # make spectral colormap + r = np.array([ + 
0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60, 64, 68, 72, 76, 80, + 84, 88, 92, 96, 100, 104, 108, 112, 116, 120, 124, 128, 128, 128, 128, 128, 128, + 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 128, 120, 112, 104, 96, 88, + 80, 72, 64, 56, 48, 40, 32, 24, 16, 8, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 7, 11, 15, 19, 23, + 27, 31, 35, 39, 43, 47, 51, 55, 59, 63, 67, 71, 75, 79, 83, 87, 91, 95, 99, 103, + 107, 111, 115, 119, 123, 127, 131, 135, 139, 143, 147, 151, 155, 159, 163, 167, + 171, 175, 179, 183, 187, 191, 195, 199, 203, 207, 211, 215, 219, 223, 227, 231, + 235, 239, 243, 247, 251, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255 + ]) + g = np.array([ + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 9, 9, 8, 8, 7, 7, 6, 6, 5, 5, 5, 4, 4, 3, 3, + 2, 2, 1, 1, 0, 0, 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, + 119, 127, 135, 143, 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, + 247, 255, 247, 239, 231, 223, 215, 207, 199, 191, 183, 175, 167, 159, 151, 143, + 135, 128, 129, 131, 132, 134, 135, 137, 139, 140, 142, 143, 145, 147, 148, 150, + 151, 153, 154, 156, 158, 159, 161, 162, 164, 166, 167, 169, 170, 172, 174, 175, + 177, 178, 180, 181, 183, 185, 186, 188, 189, 191, 193, 194, 196, 197, 199, 201, + 202, 204, 205, 207, 208, 210, 212, 213, 215, 216, 218, 220, 221, 223, 224, 226, + 228, 229, 231, 232, 234, 235, 237, 239, 240, 242, 243, 245, 247, 248, 250, 251, + 253, 255, 251, 247, 243, 239, 235, 231, 227, 223, 219, 215, 211, 
207, 203, 199, + 195, 191, 187, 183, 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, + 131, 127, 123, 119, 115, 111, 107, 103, 99, 95, 91, 87, 83, 79, 75, 71, 67, 63, + 59, 55, 51, 47, 43, 39, 35, 31, 27, 23, 19, 15, 11, 7, 3, 0, 8, 16, 24, 32, 41, + 49, 57, 65, 74, 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, + 189, 197, 205, 213, 222, 230, 238, 246, 254 + ]) + b = np.array([ + 0, 7, 15, 23, 31, 39, 47, 55, 63, 71, 79, 87, 95, 103, 111, 119, 127, 135, 143, + 151, 159, 167, 175, 183, 191, 199, 207, 215, 223, 231, 239, 247, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, + 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 255, 251, 247, + 243, 239, 235, 231, 227, 223, 219, 215, 211, 207, 203, 199, 195, 191, 187, 183, + 179, 175, 171, 167, 163, 159, 155, 151, 147, 143, 139, 135, 131, 128, 126, 124, + 122, 120, 118, 116, 114, 112, 110, 108, 106, 104, 102, 100, 98, 96, 94, 92, 90, + 88, 86, 84, 82, 80, 78, 76, 74, 72, 70, 68, 66, 64, 62, 60, 58, 56, 54, 52, 50, + 48, 46, 44, 42, 40, 38, 36, 34, 32, 30, 28, 26, 24, 22, 20, 18, 16, 14, 12, 10, + 8, 6, 4, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 8, 16, 24, 32, 41, 49, 57, 65, 74, + 82, 90, 98, 106, 115, 123, 131, 139, 148, 156, 164, 172, 180, 189, 197, 205, + 213, 222, 230, 238, 246, 254 + ]) + color = (np.vstack((r, g, b)).T).astype(np.uint8) + spectral = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) + return spectral + + +def make_cmap(cm=0): + # make a single channel colormap + r = np.arange(0, 256) + color = np.zeros((256, 3)) + color[:, cm] = r + color = color.astype(np.uint8) + cmap = pg.ColorMap(pos=np.linspace(0.0, 255, 256), color=color) + return cmap + + +def run(image=None): + from ..io import logger_setup + logger, log_file = logger_setup() + # 
Always start by initializing Qt (only once per application) + warnings.filterwarnings("ignore") + app = QApplication(sys.argv) + icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png") + guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png") + if not icon_path.is_file(): + cp_dir = pathlib.Path.home().joinpath(".cellpose") + cp_dir.mkdir(exist_ok=True) + print("downloading logo") + download_url_to_file( + "https://www.cellpose.org/static/images/cellpose_transparent.png", + icon_path, progress=True) + if not guip_path.is_file(): + print("downloading help window image") + download_url_to_file("https://www.cellpose.org/static/images/cellposeSAM_gui.png", + guip_path, progress=True) + icon_path = str(icon_path.resolve()) + app_icon = QtGui.QIcon() + app_icon.addFile(icon_path, QtCore.QSize(16, 16)) + app_icon.addFile(icon_path, QtCore.QSize(24, 24)) + app_icon.addFile(icon_path, QtCore.QSize(32, 32)) + app_icon.addFile(icon_path, QtCore.QSize(48, 48)) + app_icon.addFile(icon_path, QtCore.QSize(64, 64)) + app_icon.addFile(icon_path, QtCore.QSize(256, 256)) + app.setWindowIcon(app_icon) + app.setStyle("Fusion") + app.setPalette(guiparts.DarkPalette()) + MainW(image=image, logger=logger) + ret = app.exec_() + sys.exit(ret) + + +class MainW(QMainWindow): + + def __init__(self, image=None, logger=None): + super(MainW, self).__init__() + + self.logger = logger + pg.setConfigOptions(imageAxisOrder="row-major") + self.setGeometry(50, 50, 1200, 1000) + self.setWindowTitle(f"cellpose v{version}") + self.cp_path = os.path.dirname(os.path.realpath(__file__)) + app_icon = QtGui.QIcon() + icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png") + icon_path = str(icon_path.resolve()) + app_icon.addFile(icon_path, QtCore.QSize(16, 16)) + app_icon.addFile(icon_path, QtCore.QSize(24, 24)) + app_icon.addFile(icon_path, QtCore.QSize(32, 32)) + app_icon.addFile(icon_path, QtCore.QSize(48, 48)) + app_icon.addFile(icon_path, QtCore.QSize(64, 64)) + 
app_icon.addFile(icon_path, QtCore.QSize(256, 256)) + self.setWindowIcon(app_icon) + # rgb(150,255,150) + self.setStyleSheet(guiparts.stylesheet()) + + menus.mainmenu(self) + menus.editmenu(self) + menus.modelmenu(self) + menus.helpmenu(self) + + self.stylePressed = """QPushButton {Text-align: center; + background-color: rgb(150,50,150); + border-color: white; + color:white;} + QToolTip { + background-color: black; + color: white; + border: black solid 1px + }""" + self.styleUnpressed = """QPushButton {Text-align: center; + background-color: rgb(50,50,50); + border-color: white; + color:white;} + QToolTip { + background-color: black; + color: white; + border: black solid 1px + }""" + self.loaded = False + + # ---- MAIN WIDGET LAYOUT ---- # + self.cwidget = QWidget(self) + self.lmain = QGridLayout() + self.cwidget.setLayout(self.lmain) + self.setCentralWidget(self.cwidget) + self.lmain.setVerticalSpacing(0) + self.lmain.setContentsMargins(0, 0, 0, 10) + + self.imask = 0 + self.scrollarea = QScrollArea() + self.scrollarea.setVerticalScrollBarPolicy(QtCore.Qt.ScrollBarAlwaysOn) + self.scrollarea.setStyleSheet("""QScrollArea { border: none }""") + self.scrollarea.setWidgetResizable(True) + self.swidget = QWidget(self) + self.scrollarea.setWidget(self.swidget) + self.l0 = QGridLayout() + self.swidget.setLayout(self.l0) + b = self.make_buttons() + self.lmain.addWidget(self.scrollarea, 0, 0, 39, 9) + + # ---- drawing area ---- # + self.win = pg.GraphicsLayoutWidget() + + self.lmain.addWidget(self.win, 0, 9, 40, 30) + + self.win.scene().sigMouseClicked.connect(self.plot_clicked) + self.win.scene().sigMouseMoved.connect(self.mouse_moved) + self.make_viewbox() + self.lmain.setColumnStretch(10, 1) + bwrmap = make_bwr() + self.bwr = bwrmap.getLookupTable(start=0.0, stop=255.0, alpha=False) + self.cmap = [] + # spectral colormap + self.cmap.append(make_spectral().getLookupTable(start=0.0, stop=255.0, + alpha=False)) + # single channel colormaps + for i in range(3): + 
self.cmap.append( + make_cmap(i).getLookupTable(start=0.0, stop=255.0, alpha=False)) + + if MATPLOTLIB: + self.colormap = (plt.get_cmap("gist_ncar")(np.linspace(0.0, .9, 1000000)) * + 255).astype(np.uint8) + np.random.seed(42) # make colors stable + self.colormap = self.colormap[np.random.permutation(1000000)] + else: + np.random.seed(42) # make colors stable + self.colormap = ((np.random.rand(1000000, 3) * 0.8 + 0.1) * 255).astype( + np.uint8) + self.NZ = 1 + self.restore = None + self.ratio = 1. + self.reset() + + # This needs to go after .reset() is called to get state fully set up: + self.autobtn.checkStateChanged.connect(self.compute_saturation_if_checked) + + self.load_3D = False + + # if called with image, load it + if image is not None: + self.filename = image + io._load_image(self, self.filename) + + # training settings + d = datetime.datetime.now() + self.training_params = { + "model_index": 0, + "learning_rate": 1e-5, + "weight_decay": 0.1, + "n_epochs": 100, + "model_name": "cpsam" + d.strftime("_%Y%m%d_%H%M%S"), + } + + self.stitch_threshold = 0. + self.flow3D_smooth = 0. + self.anisotropy = 1. 
+ self.min_size = 15 + + self.setAcceptDrops(True) + self.win.show() + self.show() + + def help_window(self): + HW = guiparts.HelpWindow(self) + HW.show() + + def train_help_window(self): + THW = guiparts.TrainHelpWindow(self) + THW.show() + + def gui_window(self): + EG = guiparts.ExampleGUI(self) + EG.show() + + def make_buttons(self): + self.boldfont = QtGui.QFont("Arial", 11, QtGui.QFont.Bold) + self.boldmedfont = QtGui.QFont("Arial", 9, QtGui.QFont.Bold) + self.medfont = QtGui.QFont("Arial", 9) + self.smallfont = QtGui.QFont("Arial", 8) + + b = 0 + self.satBox = QGroupBox("Views") + self.satBox.setFont(self.boldfont) + self.satBoxG = QGridLayout() + self.satBox.setLayout(self.satBoxG) + self.l0.addWidget(self.satBox, b, 0, 1, 9) + + widget_row = 0 + self.view = 0 # 0=image, 1=flowsXY, 2=flowsZ, 3=cellprob + self.color = 0 # 0=RGB, 1=gray, 2=R, 3=G, 4=B + self.RGBDropDown = QComboBox() + self.RGBDropDown.addItems( + ["RGB", "red=R", "green=G", "blue=B", "gray", "spectral"]) + self.RGBDropDown.setFont(self.medfont) + self.RGBDropDown.currentIndexChanged.connect(self.color_choose) + self.satBoxG.addWidget(self.RGBDropDown, widget_row, 0, 1, 3) + + label = QLabel("

[↑ / ↓ or W/S]

") + label.setFont(self.smallfont) + self.satBoxG.addWidget(label, widget_row, 3, 1, 3) + label = QLabel("[R / G / B \n toggles color ]") + label.setFont(self.smallfont) + self.satBoxG.addWidget(label, widget_row, 6, 1, 3) + + widget_row += 1 + self.ViewDropDown = QComboBox() + self.ViewDropDown.addItems(["image", "gradXY", "cellprob", "restored"]) + self.ViewDropDown.setFont(self.medfont) + self.ViewDropDown.model().item(3).setEnabled(False) + self.ViewDropDown.currentIndexChanged.connect(self.update_plot) + self.satBoxG.addWidget(self.ViewDropDown, widget_row, 0, 2, 3) + + label = QLabel("[pageup / pagedown]") + label.setFont(self.smallfont) + self.satBoxG.addWidget(label, widget_row, 3, 1, 5) + + widget_row += 2 + label = QLabel("") + label.setToolTip( + "NOTE: manually changing the saturation bars does not affect normalization in segmentation" + ) + self.satBoxG.addWidget(label, widget_row, 0, 1, 5) + + self.autobtn = QCheckBox("auto-adjust saturation") + self.autobtn.setToolTip("sets scale-bars as normalized for segmentation") + self.autobtn.setFont(self.medfont) + self.autobtn.setChecked(True) + self.satBoxG.addWidget(self.autobtn, widget_row, 1, 1, 8) + + widget_row += 1 + self.sliders = [] + colors = [[255, 0, 0], [0, 255, 0], [0, 0, 255], [100, 100, 100]] + colornames = ["red", "Chartreuse", "DodgerBlue"] + names = ["red", "green", "blue"] + for r in range(3): + widget_row += 1 + if r == 0: + label = QLabel('gray/
red') + else: + label = QLabel(names[r] + ":") + label.setStyleSheet(f"color: {colornames[r]}") + label.setFont(self.boldmedfont) + self.satBoxG.addWidget(label, widget_row, 0, 1, 2) + self.sliders.append(Slider(self, names[r], colors[r])) + self.sliders[-1].setMinimum(-.1) + self.sliders[-1].setMaximum(255.1) + self.sliders[-1].setValue([0, 255]) + self.sliders[-1].setToolTip( + "NOTE: manually changing the saturation bars does not affect normalization in segmentation" + ) + self.satBoxG.addWidget(self.sliders[-1], widget_row, 2, 1, 7) + + b += 1 + self.drawBox = QGroupBox("Drawing") + self.drawBox.setFont(self.boldfont) + self.drawBoxG = QGridLayout() + self.drawBox.setLayout(self.drawBoxG) + self.l0.addWidget(self.drawBox, b, 0, 1, 9) + self.autosave = True + + widget_row = 0 + self.brush_size = 3 + self.BrushChoose = QComboBox() + self.BrushChoose.addItems(["1", "3", "5", "7", "9"]) + self.BrushChoose.currentIndexChanged.connect(self.brush_choose) + self.BrushChoose.setFixedWidth(40) + self.BrushChoose.setFont(self.medfont) + self.drawBoxG.addWidget(self.BrushChoose, widget_row, 3, 1, 2) + label = QLabel("brush size:") + label.setFont(self.medfont) + self.drawBoxG.addWidget(label, widget_row, 0, 1, 3) + + widget_row += 1 + # turn off masks + self.layer_off = False + self.masksOn = True + self.MCheckBox = QCheckBox("MASKS ON [X]") + self.MCheckBox.setFont(self.medfont) + self.MCheckBox.setChecked(True) + self.MCheckBox.toggled.connect(self.toggle_masks) + self.drawBoxG.addWidget(self.MCheckBox, widget_row, 0, 1, 5) + + widget_row += 1 + # turn off outlines + self.outlinesOn = False # turn off by default + self.OCheckBox = QCheckBox("outlines on [Z]") + self.OCheckBox.setFont(self.medfont) + self.drawBoxG.addWidget(self.OCheckBox, widget_row, 0, 1, 5) + self.OCheckBox.setChecked(False) + self.OCheckBox.toggled.connect(self.toggle_masks) + + widget_row += 1 + self.SCheckBox = QCheckBox("single stroke") + self.SCheckBox.setFont(self.medfont) + 
self.SCheckBox.setChecked(True) + self.SCheckBox.toggled.connect(self.autosave_on) + self.SCheckBox.setEnabled(True) + self.drawBoxG.addWidget(self.SCheckBox, widget_row, 0, 1, 5) + + # buttons for deleting multiple cells + self.deleteBox = QGroupBox("delete multiple ROIs") + self.deleteBox.setStyleSheet("color: rgb(200, 200, 200)") + self.deleteBox.setFont(self.medfont) + self.deleteBoxG = QGridLayout() + self.deleteBox.setLayout(self.deleteBoxG) + self.drawBoxG.addWidget(self.deleteBox, 0, 5, 4, 4) + self.MakeDeletionRegionButton = QPushButton("region-select") + self.MakeDeletionRegionButton.clicked.connect(self.remove_region_cells) + self.deleteBoxG.addWidget(self.MakeDeletionRegionButton, 0, 0, 1, 4) + self.MakeDeletionRegionButton.setFont(self.smallfont) + self.MakeDeletionRegionButton.setFixedWidth(70) + self.DeleteMultipleROIButton = QPushButton("click-select") + self.DeleteMultipleROIButton.clicked.connect(self.delete_multiple_cells) + self.deleteBoxG.addWidget(self.DeleteMultipleROIButton, 1, 0, 1, 4) + self.DeleteMultipleROIButton.setFont(self.smallfont) + self.DeleteMultipleROIButton.setFixedWidth(70) + self.DoneDeleteMultipleROIButton = QPushButton("done") + self.DoneDeleteMultipleROIButton.clicked.connect( + self.done_remove_multiple_cells) + self.deleteBoxG.addWidget(self.DoneDeleteMultipleROIButton, 2, 0, 1, 2) + self.DoneDeleteMultipleROIButton.setFont(self.smallfont) + self.DoneDeleteMultipleROIButton.setFixedWidth(35) + self.CancelDeleteMultipleROIButton = QPushButton("cancel") + self.CancelDeleteMultipleROIButton.clicked.connect(self.cancel_remove_multiple) + self.deleteBoxG.addWidget(self.CancelDeleteMultipleROIButton, 2, 2, 1, 2) + self.CancelDeleteMultipleROIButton.setFont(self.smallfont) + self.CancelDeleteMultipleROIButton.setFixedWidth(35) + + b += 1 + widget_row = 0 + self.segBox = QGroupBox("Segmentation") + self.segBoxG = QGridLayout() + self.segBox.setLayout(self.segBoxG) + self.l0.addWidget(self.segBox, b, 0, 1, 9) + 
self.segBox.setFont(self.boldfont) + + widget_row += 1 + + # use GPU + self.useGPU = QCheckBox("use GPU") + self.useGPU.setToolTip( + "if you have specially installed the cuda version of torch, then you can activate this" + ) + self.useGPU.setFont(self.medfont) + self.check_gpu() + self.segBoxG.addWidget(self.useGPU, widget_row, 0, 1, 3) + + # compute segmentation with general models + self.net_text = ["run CPSAM"] + nett = ["cellpose super-generalist model"] + + self.StyleButtons = [] + jj = 4 + for j in range(len(self.net_text)): + self.StyleButtons.append( + guiparts.ModelButton(self, self.net_text[j], self.net_text[j])) + w = 5 + self.segBoxG.addWidget(self.StyleButtons[-1], widget_row, jj, 1, w) + jj += w + self.StyleButtons[-1].setToolTip(nett[j]) + + widget_row += 1 + self.ncells = guiparts.ObservableVariable(0) + self.roi_count = QLabel() + self.roi_count.setFont(self.boldfont) + self.roi_count.setAlignment(QtCore.Qt.AlignLeft) + self.ncells.valueChanged.connect( + lambda n: self.roi_count.setText(f'{str(n)} ROIs') + ) + + self.segBoxG.addWidget(self.roi_count, widget_row, 0, 1, 4) + + self.progress = QProgressBar(self) + self.segBoxG.addWidget(self.progress, widget_row, 4, 1, 5) + + widget_row += 1 + + ############################### Segmentation settings ############################### + self.additional_seg_settings_qcollapsible = QCollapsible("additional settings") + self.additional_seg_settings_qcollapsible.setFont(self.medfont) + self.additional_seg_settings_qcollapsible._toggle_btn.setFont(self.medfont) + self.segmentation_settings = guiparts.SegmentationSettings(self.medfont) + self.additional_seg_settings_qcollapsible.setContent(self.segmentation_settings) + self.segBoxG.addWidget(self.additional_seg_settings_qcollapsible, widget_row, 0, 1, 9) + + # connect edits to image processing steps: + self.segmentation_settings.diameter_box.editingFinished.connect(self.update_scale) + 
self.segmentation_settings.flow_threshold_box.returnPressed.connect(self.compute_cprob) + self.segmentation_settings.cellprob_threshold_box.returnPressed.connect(self.compute_cprob) + self.segmentation_settings.niter_box.returnPressed.connect(self.compute_cprob) + + # Needed to do this for the drop down to not be open on startup + self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(True) + self.additional_seg_settings_qcollapsible._toggle_btn.setChecked(False) + + b += 1 + self.modelBox = QGroupBox("user-trained models") + self.modelBoxG = QGridLayout() + self.modelBox.setLayout(self.modelBoxG) + self.l0.addWidget(self.modelBox, b, 0, 1, 9) + self.modelBox.setFont(self.boldfont) + # choose models + self.ModelChooseC = QComboBox() + self.ModelChooseC.setFont(self.medfont) + current_index = 0 + self.ModelChooseC.addItems(["custom models"]) + if len(self.model_strings) > 0: + self.ModelChooseC.addItems(self.model_strings) + self.ModelChooseC.setFixedWidth(175) + self.ModelChooseC.setCurrentIndex(current_index) + tipstr = 'add or train your own models in the "Models" file menu and choose model here' + self.ModelChooseC.setToolTip(tipstr) + self.ModelChooseC.activated.connect(lambda: self.model_choose(custom=True)) + self.modelBoxG.addWidget(self.ModelChooseC, widget_row, 0, 1, 8) + + # compute segmentation w/ custom model + self.ModelButtonC = QPushButton(u"run") + self.ModelButtonC.setFont(self.medfont) + self.ModelButtonC.setFixedWidth(35) + self.ModelButtonC.clicked.connect( + lambda: self.compute_segmentation(custom=True)) + self.modelBoxG.addWidget(self.ModelButtonC, widget_row, 8, 1, 1) + self.ModelButtonC.setEnabled(False) + + + b += 1 + self.filterBox = QGroupBox("Image filtering") + self.filterBox.setFont(self.boldfont) + self.filterBox_grid_layout = QGridLayout() + self.filterBox.setLayout(self.filterBox_grid_layout) + self.l0.addWidget(self.filterBox, b, 0, 1, 9) + + widget_row = 0 + + # Filtering + self.FilterButtons = [] + nett = [ + "clear 
restore/filter", + "filter image (settings below)", + ] + self.filter_text = ["none", + "filter", + ] + self.restore = None + self.ratio = 1. + jj = 0 + w = 3 + for j in range(len(self.filter_text)): + self.FilterButtons.append( + guiparts.FilterButton(self, self.filter_text[j])) + self.filterBox_grid_layout.addWidget(self.FilterButtons[-1], widget_row, jj, 1, w) + self.FilterButtons[-1].setFixedWidth(75) + self.FilterButtons[-1].setToolTip(nett[j]) + self.FilterButtons[-1].setFont(self.medfont) + widget_row += 1 if j%2==1 else 0 + jj = 0 if j%2==1 else jj + w + + self.save_norm = QCheckBox("save restored/filtered image") + self.save_norm.setFont(self.medfont) + self.save_norm.setToolTip("save restored/filtered image in _seg.npy file") + self.save_norm.setChecked(True) + + widget_row += 2 + + self.filtBox = QCollapsible("custom filter settings") + self.filtBox._toggle_btn.setFont(self.medfont) + self.filtBoxG = QGridLayout() + _content = QWidget() + _content.setLayout(self.filtBoxG) + _content.setMaximumHeight(0) + _content.setMinimumHeight(0) + self.filtBox.setContent(_content) + self.filterBox_grid_layout.addWidget(self.filtBox, widget_row, 0, 1, 9) + + self.filt_vals = [0., 0., 0., 0.] 
+ self.filt_edits = [] + labels = [ + "sharpen\nradius", "smooth\nradius", "tile_norm\nblocksize", + "tile_norm\nsmooth3D" + ] + tooltips = [ + "set size of surround-subtraction filter for sharpening image", + "set size of gaussian filter for smoothing image", + "set size of tiles to use to normalize image", + "set amount of smoothing of normalization values across planes" + ] + + for p in range(4): + label = QLabel(f"{labels[p]}:") + label.setToolTip(tooltips[p]) + label.setFont(self.medfont) + self.filtBoxG.addWidget(label, widget_row + p // 2, 4 * (p % 2), 1, 2) + self.filt_edits.append(QLineEdit()) + self.filt_edits[p].setText(str(self.filt_vals[p])) + self.filt_edits[p].setFixedWidth(40) + self.filt_edits[p].setFont(self.medfont) + self.filtBoxG.addWidget(self.filt_edits[p], widget_row + p // 2, 4 * (p % 2) + 2, 1, + 2) + self.filt_edits[p].setToolTip(tooltips[p]) + + widget_row += 3 + self.norm3D_cb = QCheckBox("norm3D") + self.norm3D_cb.setFont(self.medfont) + self.norm3D_cb.setChecked(True) + self.norm3D_cb.setToolTip("run same normalization across planes") + self.filtBoxG.addWidget(self.norm3D_cb, widget_row, 0, 1, 3) + + + return b + + def level_change(self, r): + r = ["red", "green", "blue"].index(r) + if self.loaded: + sval = self.sliders[r].value() + self.saturation[r][self.currentZ] = sval + if not self.autobtn.isChecked(): + for r in range(3): + for i in range(len(self.saturation[r])): + self.saturation[r][i] = self.saturation[r][self.currentZ] + self.update_plot() + + def keyPressEvent(self, event): + if self.loaded: + if not (event.modifiers() & + (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier | + QtCore.Qt.AltModifier) or self.in_stroke): + updated = False + if len(self.current_point_set) > 0: + if event.key() == QtCore.Qt.Key_Return: + self.add_set() + else: + nviews = self.ViewDropDown.count() - 1 + nviews += int( + self.ViewDropDown.model().item(self.ViewDropDown.count() - + 1).isEnabled()) + if event.key() == QtCore.Qt.Key_X: + 
self.MCheckBox.toggle() + if event.key() == QtCore.Qt.Key_Z: + self.OCheckBox.toggle() + if event.key() == QtCore.Qt.Key_Left or event.key( + ) == QtCore.Qt.Key_A: + self.get_prev_image() + elif event.key() == QtCore.Qt.Key_Right or event.key( + ) == QtCore.Qt.Key_D: + self.get_next_image() + elif event.key() == QtCore.Qt.Key_PageDown: + self.view = (self.view + 1) % (nviews) + self.ViewDropDown.setCurrentIndex(self.view) + elif event.key() == QtCore.Qt.Key_PageUp: + self.view = (self.view - 1) % (nviews) + self.ViewDropDown.setCurrentIndex(self.view) + + # can change background or stroke size if cell not finished + if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W: + self.color = (self.color - 1) % (6) + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_Down or event.key( + ) == QtCore.Qt.Key_S: + self.color = (self.color + 1) % (6) + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_R: + if self.color != 1: + self.color = 1 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_G: + if self.color != 2: + self.color = 2 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_B: + if self.color != 3: + self.color = 3 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif (event.key() == QtCore.Qt.Key_Comma or + event.key() == QtCore.Qt.Key_Period): + count = self.BrushChoose.count() + gci = self.BrushChoose.currentIndex() + if event.key() == QtCore.Qt.Key_Comma: + gci = max(0, gci - 1) + else: + gci = min(count - 1, gci + 1) + self.BrushChoose.setCurrentIndex(gci) + self.brush_choose() + if not updated: + self.update_plot() + if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal: + self.p0.keyPressEvent(event) + + def autosave_on(self): + if self.SCheckBox.isChecked(): + self.autosave = True + else: + self.autosave = False + + def 
check_gpu(self, torch=True): + # also decide whether or not to use torch + self.useGPU.setChecked(False) + self.useGPU.setEnabled(False) + if core.use_gpu(use_torch=True): + self.useGPU.setEnabled(True) + self.useGPU.setChecked(True) + else: + self.useGPU.setStyleSheet("color: rgb(80,80,80);") + + + def model_choose(self, custom=False): + index = self.ModelChooseC.currentIndex( + ) if custom else self.ModelChooseB.currentIndex() + if index > 0: + if custom: + model_name = self.ModelChooseC.currentText() + else: + model_name = self.net_names[index - 1] + print(f"GUI_INFO: selected model {model_name}, loading now") + self.initialize_model(model_name=model_name, custom=custom) + + def toggle_scale(self): + if self.scale_on: + self.p0.removeItem(self.scale) + self.scale_on = False + else: + self.p0.addItem(self.scale) + self.scale_on = True + + def enable_buttons(self): + if len(self.model_strings) > 0: + self.ModelButtonC.setEnabled(True) + for i in range(len(self.StyleButtons)): + self.StyleButtons[i].setEnabled(True) + + for i in range(len(self.FilterButtons)): + self.FilterButtons[i].setEnabled(True) + if self.load_3D: + self.FilterButtons[-2].setEnabled(False) + + self.newmodel.setEnabled(True) + self.loadMasks.setEnabled(True) + + for n in range(self.nchan): + self.sliders[n].setEnabled(True) + for n in range(self.nchan, 3): + self.sliders[n].setEnabled(True) + + self.toggle_mask_ops() + + self.update_plot() + self.setWindowTitle(self.filename) + + def disable_buttons_removeROIs(self): + if len(self.model_strings) > 0: + self.ModelButtonC.setEnabled(False) + for i in range(len(self.StyleButtons)): + self.StyleButtons[i].setEnabled(False) + self.newmodel.setEnabled(False) + self.loadMasks.setEnabled(False) + self.saveSet.setEnabled(False) + self.savePNG.setEnabled(False) + self.saveFlows.setEnabled(False) + self.saveOutlines.setEnabled(False) + self.saveROIs.setEnabled(False) + + self.MakeDeletionRegionButton.setEnabled(False) + 
self.DeleteMultipleROIButton.setEnabled(False) + self.DoneDeleteMultipleROIButton.setEnabled(True) + self.CancelDeleteMultipleROIButton.setEnabled(True) + + def toggle_mask_ops(self): + self.update_layer() + self.toggle_saving() + self.toggle_removals() + + def toggle_saving(self): + if self.ncells > 0: + self.saveSet.setEnabled(True) + self.savePNG.setEnabled(True) + self.saveFlows.setEnabled(True) + self.saveOutlines.setEnabled(True) + self.saveROIs.setEnabled(True) + else: + self.saveSet.setEnabled(False) + self.savePNG.setEnabled(False) + self.saveFlows.setEnabled(False) + self.saveOutlines.setEnabled(False) + self.saveROIs.setEnabled(False) + + def toggle_removals(self): + if self.ncells > 0: + self.ClearButton.setEnabled(True) + self.remcell.setEnabled(True) + self.undo.setEnabled(True) + self.MakeDeletionRegionButton.setEnabled(True) + self.DeleteMultipleROIButton.setEnabled(True) + self.DoneDeleteMultipleROIButton.setEnabled(False) + self.CancelDeleteMultipleROIButton.setEnabled(False) + else: + self.ClearButton.setEnabled(False) + self.remcell.setEnabled(False) + self.undo.setEnabled(False) + self.MakeDeletionRegionButton.setEnabled(False) + self.DeleteMultipleROIButton.setEnabled(False) + self.DoneDeleteMultipleROIButton.setEnabled(False) + self.CancelDeleteMultipleROIButton.setEnabled(False) + + def remove_action(self): + if self.selected > 0: + self.remove_cell(self.selected) + + def undo_action(self): + if (len(self.strokes) > 0 and self.strokes[-1][0][0] == self.currentZ): + self.remove_stroke() + else: + # remove previous cell + if self.ncells > 0: + self.remove_cell(self.ncells.get()) + + def undo_remove_action(self): + self.undo_remove_cell() + + def get_files(self): + folder = os.path.dirname(self.filename) + mask_filter = "_masks" + images = get_image_files(folder, mask_filter) + fnames = [os.path.split(images[k])[-1] for k in range(len(images))] + f0 = os.path.split(self.filename)[-1] + idx = np.nonzero(np.array(fnames) == f0)[0][0] + return 
images, idx + + def get_prev_image(self): + images, idx = self.get_files() + idx = (idx - 1) % len(images) + io._load_image(self, filename=images[idx]) + + def get_next_image(self, load_seg=True): + images, idx = self.get_files() + idx = (idx + 1) % len(images) + io._load_image(self, filename=images[idx], load_seg=load_seg) + + def dragEnterEvent(self, event): + if event.mimeData().hasUrls(): + event.accept() + else: + event.ignore() + + def dropEvent(self, event): + files = [u.toLocalFile() for u in event.mimeData().urls()] + if os.path.splitext(files[0])[-1] == ".npy": + io._load_seg(self, filename=files[0], load_3D=self.load_3D) + else: + io._load_image(self, filename=files[0], load_seg=True, load_3D=self.load_3D) + + def toggle_masks(self): + if self.MCheckBox.isChecked(): + self.masksOn = True + else: + self.masksOn = False + if self.OCheckBox.isChecked(): + self.outlinesOn = True + else: + self.outlinesOn = False + if not self.masksOn and not self.outlinesOn: + self.p0.removeItem(self.layer) + self.layer_off = True + else: + if self.layer_off: + self.p0.addItem(self.layer) + self.draw_layer() + self.update_layer() + if self.loaded: + self.update_plot() + self.update_layer() + + def make_viewbox(self): + self.p0 = guiparts.ViewBoxNoRightDrag(parent=self, lockAspect=True, + name="plot1", border=[100, 100, + 100], invertY=True) + self.p0.setCursor(QtCore.Qt.CrossCursor) + self.brush_size = 3 + self.win.addItem(self.p0, 0, 0, rowspan=1, colspan=1) + self.p0.setMenuEnabled(False) + self.p0.setMouseEnabled(x=True, y=True) + self.img = pg.ImageItem(viewbox=self.p0, parent=self) + self.img.autoDownsample = False + self.layer = guiparts.ImageDraw(viewbox=self.p0, parent=self) + self.layer.setLevels([0, 255]) + self.scale = pg.ImageItem(viewbox=self.p0, parent=self) + self.scale.setLevels([0, 255]) + self.p0.scene().contextMenuItem = self.p0 + self.Ly, self.Lx = 512, 512 + self.p0.addItem(self.img) + self.p0.addItem(self.layer) + self.p0.addItem(self.scale) + + def 
reset(self): + # ---- start sets of points ---- # + self.selected = 0 + self.nchan = 3 + self.loaded = False + self.channel = [0, 1] + self.current_point_set = [] + self.in_stroke = False + self.strokes = [] + self.stroke_appended = True + self.resize = False + self.ncells.reset() + self.zdraw = [] + self.removed_cell = [] + self.cellcolors = np.array([255, 255, 255])[np.newaxis, :] + + # -- zero out image stack -- # + self.opacity = 128 # how opaque masks should be + self.outcolor = [200, 200, 255, 200] + self.NZ, self.Ly, self.Lx = 1, 256, 256 + self.saturation = self.saturation if hasattr(self, 'saturation') else [] + + # only adjust the saturation if auto-adjust is on: + if self.autobtn.isChecked(): + for r in range(3): + self.saturation.append([[0, 255] for n in range(self.NZ)]) + self.sliders[r].setValue([0, 255]) + self.sliders[r].setEnabled(False) + self.sliders[r].show() + self.currentZ = 0 + self.flows = [[], [], [], [], [[]]] + # masks matrix + # image matrix with a scale disk + self.stack = np.zeros((1, self.Ly, self.Lx, 3)) + self.Lyr, self.Lxr = self.Ly, self.Lx + self.Ly0, self.Lx0 = self.Ly, self.Lx + self.radii = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) + self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) + self.cellpix = np.zeros((1, self.Ly, self.Lx), np.uint16) + self.outpix = np.zeros((1, self.Ly, self.Lx), np.uint16) + self.ismanual = np.zeros(0, "bool") + + # -- set menus to default -- # + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + self.view = 0 + self.ViewDropDown.setCurrentIndex(0) + self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False) + self.delete_restore() + + self.clear_all() + + self.filename = [] + self.loaded = False + self.recompute_masks = False + + self.deleting_multiple = False + self.removing_cells_list = [] + self.removing_region = False + self.remove_roi_obj = None + + def delete_restore(self): + """ delete restored imgs but don't reset settings """ + if hasattr(self, 
"stack_filtered"): + del self.stack_filtered + if hasattr(self, "cellpix_orig"): + self.cellpix = self.cellpix_orig.copy() + self.outpix = self.outpix_orig.copy() + del self.outpix_orig, self.outpix_resize + del self.cellpix_orig, self.cellpix_resize + + def clear_restore(self): + """ delete restored imgs and reset settings """ + print("GUI_INFO: clearing restored image") + self.ViewDropDown.model().item(self.ViewDropDown.count() - 1).setEnabled(False) + if self.ViewDropDown.currentIndex() == self.ViewDropDown.count() - 1: + self.ViewDropDown.setCurrentIndex(0) + self.delete_restore() + self.restore = None + self.ratio = 1. + self.set_normalize_params(self.get_normalize_params()) + + def brush_choose(self): + self.brush_size = self.BrushChoose.currentIndex() * 2 + 1 + if self.loaded: + self.layer.setDrawKernel(kernel_size=self.brush_size) + self.update_layer() + + def clear_all(self): + self.prev_selected = 0 + self.selected = 0 + if self.restore and "upsample" in self.restore: + self.layerz = 0 * np.ones((self.Lyr, self.Lxr, 4), np.uint8) + self.cellpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16) + self.outpix = np.zeros((self.NZ, self.Lyr, self.Lxr), np.uint16) + self.cellpix_resize = self.cellpix.copy() + self.outpix_resize = self.outpix.copy() + self.cellpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16) + self.outpix_orig = np.zeros((self.NZ, self.Ly0, self.Lx0), np.uint16) + else: + self.layerz = 0 * np.ones((self.Ly, self.Lx, 4), np.uint8) + self.cellpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16) + self.outpix = np.zeros((self.NZ, self.Ly, self.Lx), np.uint16) + + self.cellcolors = np.array([255, 255, 255])[np.newaxis, :] + self.ncells.reset() + self.toggle_removals() + self.update_scale() + self.update_layer() + + def select_cell(self, idx): + self.prev_selected = self.selected + self.selected = idx + if self.selected > 0: + z = self.currentZ + self.layerz[self.cellpix[z] == idx] = np.array( + [255, 255, 255, self.opacity]) + 
self.update_layer() + + def select_cell_multi(self, idx): + if idx > 0: + z = self.currentZ + self.layerz[self.cellpix[z] == idx] = np.array( + [255, 255, 255, self.opacity]) + self.update_layer() + + def unselect_cell(self): + if self.selected > 0: + idx = self.selected + if idx < (self.ncells.get() + 1): + z = self.currentZ + self.layerz[self.cellpix[z] == idx] = np.append( + self.cellcolors[idx], self.opacity) + if self.outlinesOn: + self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype( + np.uint8) + #[0,0,0,self.opacity]) + self.update_layer() + self.selected = 0 + + def unselect_cell_multi(self, idx): + z = self.currentZ + self.layerz[self.cellpix[z] == idx] = np.append(self.cellcolors[idx], + self.opacity) + if self.outlinesOn: + self.layerz[self.outpix[z] == idx] = np.array(self.outcolor).astype( + np.uint8) + # [0,0,0,self.opacity]) + self.update_layer() + + def remove_cell(self, idx): + if isinstance(idx, (int, np.integer)): + idx = [idx] + # because the function remove_single_cell updates the state of the cellpix and outpix arrays + # by reindexing cells to avoid gaps in the indices, we need to remove the cells in reverse order + # so that the indices are correct + idx.sort(reverse=True) + for i in idx: + self.remove_single_cell(i) + self.ncells -= len(idx) # _save_sets uses ncells + self.update_layer() + + if self.ncells == 0: + self.ClearButton.setEnabled(False) + if self.NZ == 1: + io._save_sets_with_check(self) + + + def remove_single_cell(self, idx): + # remove from manual array + self.selected = 0 + if self.NZ > 1: + zextent = ((self.cellpix == idx).sum(axis=(1, 2)) > 0).nonzero()[0] + else: + zextent = [0] + for z in zextent: + cp = self.cellpix[z] == idx + op = self.outpix[z] == idx + # remove from self.cellpix and self.outpix + self.cellpix[z, cp] = 0 + self.outpix[z, op] = 0 + if z == self.currentZ: + # remove from mask layer + self.layerz[cp] = np.array([0, 0, 0, 0]) + + # reduce other pixels by -1 + self.cellpix[self.cellpix > 
idx] -= 1 + self.outpix[self.outpix > idx] -= 1 + + if self.NZ == 1: + self.removed_cell = [ + self.ismanual[idx - 1], self.cellcolors[idx], + np.nonzero(cp), + np.nonzero(op) + ] + self.redo.setEnabled(True) + ar, ac = self.removed_cell[2] + d = datetime.datetime.now() + self.track_changes.append( + [d.strftime("%m/%d/%Y, %H:%M:%S"), "removed mask", [ar, ac]]) + # remove cell from lists + self.ismanual = np.delete(self.ismanual, idx - 1) + self.cellcolors = np.delete(self.cellcolors, [idx], axis=0) + del self.zdraw[idx - 1] + print("GUI_INFO: removed cell %d" % (idx - 1)) + + def remove_region_cells(self): + if self.removing_cells_list: + for idx in self.removing_cells_list: + self.unselect_cell_multi(idx) + self.removing_cells_list.clear() + self.disable_buttons_removeROIs() + self.removing_region = True + + self.clear_multi_selected_cells() + + # make roi region here in center of view, making ROI half the size of the view + roi_width = self.p0.viewRect().width() / 2 + x_loc = self.p0.viewRect().x() + (roi_width / 2) + roi_height = self.p0.viewRect().height() / 2 + y_loc = self.p0.viewRect().y() + (roi_height / 2) + + pos = [x_loc, y_loc] + roi = pg.RectROI(pos, [roi_width, roi_height], pen=pg.mkPen("y", width=2), + removable=True) + roi.sigRemoveRequested.connect(self.remove_roi) + roi.sigRegionChangeFinished.connect(self.roi_changed) + self.p0.addItem(roi) + self.remove_roi_obj = roi + self.roi_changed(roi) + + def delete_multiple_cells(self): + self.unselect_cell() + self.disable_buttons_removeROIs() + self.DoneDeleteMultipleROIButton.setEnabled(True) + self.MakeDeletionRegionButton.setEnabled(True) + self.CancelDeleteMultipleROIButton.setEnabled(True) + self.deleting_multiple = True + + def done_remove_multiple_cells(self): + self.deleting_multiple = False + self.removing_region = False + self.DoneDeleteMultipleROIButton.setEnabled(False) + self.MakeDeletionRegionButton.setEnabled(False) + self.CancelDeleteMultipleROIButton.setEnabled(False) + + if 
self.removing_cells_list: + self.removing_cells_list = list(set(self.removing_cells_list)) + display_remove_list = [i - 1 for i in self.removing_cells_list] + print(f"GUI_INFO: removing cells: {display_remove_list}") + self.remove_cell(self.removing_cells_list) + self.removing_cells_list.clear() + self.unselect_cell() + self.enable_buttons() + + if self.remove_roi_obj is not None: + self.remove_roi(self.remove_roi_obj) + + def merge_cells(self, idx): + self.prev_selected = self.selected + self.selected = idx + if self.selected != self.prev_selected: + for z in range(self.NZ): + ar0, ac0 = np.nonzero(self.cellpix[z] == self.prev_selected) + ar1, ac1 = np.nonzero(self.cellpix[z] == self.selected) + touching = np.logical_and((ar0[:, np.newaxis] - ar1) < 3, + (ac0[:, np.newaxis] - ac1) < 3).sum() + ar = np.hstack((ar0, ar1)) + ac = np.hstack((ac0, ac1)) + vr0, vc0 = np.nonzero(self.outpix[z] == self.prev_selected) + vr1, vc1 = np.nonzero(self.outpix[z] == self.selected) + self.outpix[z, vr0, vc0] = 0 + self.outpix[z, vr1, vc1] = 0 + if touching > 0: + mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), np.uint8) + mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1 + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0].squeeze().T + vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2 + + else: + vr = np.hstack((vr0, vr1)) + vc = np.hstack((vc0, vc1)) + color = self.cellcolors[self.prev_selected] + self.draw_mask(z, ar, ac, vr, vc, color, idx=self.prev_selected) + self.remove_cell(self.selected) + print("GUI_INFO: merged two cells") + self.update_layer() + io._save_sets_with_check(self) + self.undo.setEnabled(False) + self.redo.setEnabled(False) + + def undo_remove_cell(self): + if len(self.removed_cell) > 0: + z = 0 + ar, ac = self.removed_cell[2] + vr, vc = self.removed_cell[3] + color = self.removed_cell[1] + self.draw_mask(z, ar, ac, vr, vc, color) + self.toggle_mask_ops() + self.cellcolors = np.append(self.cellcolors, 
color[np.newaxis, :], axis=0) + self.ncells += 1 + self.ismanual = np.append(self.ismanual, self.removed_cell[0]) + self.zdraw.append([]) + print(">>> added back removed cell") + self.update_layer() + io._save_sets_with_check(self) + self.removed_cell = [] + self.redo.setEnabled(False) + + def remove_stroke(self, delete_points=True, stroke_ind=-1): + stroke = np.array(self.strokes[stroke_ind]) + cZ = self.currentZ + inZ = stroke[0, 0] == cZ + if inZ: + outpix = self.outpix[cZ, stroke[:, 1], stroke[:, 2]] > 0 + self.layerz[stroke[~outpix, 1], stroke[~outpix, 2]] = np.array([0, 0, 0, 0]) + cellpix = self.cellpix[cZ, stroke[:, 1], stroke[:, 2]] + ccol = self.cellcolors.copy() + if self.selected > 0: + ccol[self.selected] = np.array([255, 255, 255]) + col2mask = ccol[cellpix] + if self.masksOn: + col2mask = np.concatenate( + (col2mask, self.opacity * (cellpix[:, np.newaxis] > 0)), axis=-1) + else: + col2mask = np.concatenate((col2mask, 0 * (cellpix[:, np.newaxis] > 0)), + axis=-1) + self.layerz[stroke[:, 1], stroke[:, 2], :] = col2mask + if self.outlinesOn: + self.layerz[stroke[outpix, 1], stroke[outpix, + 2]] = np.array(self.outcolor) + if delete_points: + del self.current_point_set[stroke_ind] + self.update_layer() + + del self.strokes[stroke_ind] + + def plot_clicked(self, event): + if event.button()==QtCore.Qt.LeftButton \ + and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\ + and not self.removing_region: + if event.double(): + try: + self.p0.setYRange(0, self.Ly + self.pr) + except: + self.p0.setYRange(0, self.Ly) + self.p0.setXRange(0, self.Lx) + + def cancel_remove_multiple(self): + self.clear_multi_selected_cells() + self.done_remove_multiple_cells() + + def clear_multi_selected_cells(self): + # unselect all previously selected cells: + for idx in self.removing_cells_list: + self.unselect_cell_multi(idx) + self.removing_cells_list.clear() + + def add_roi(self, roi): + self.p0.addItem(roi) + self.remove_roi_obj = roi + + def 
remove_roi(self, roi): + self.clear_multi_selected_cells() + assert roi == self.remove_roi_obj + self.remove_roi_obj = None + self.p0.removeItem(roi) + self.removing_region = False + + def roi_changed(self, roi): + # find the overlapping cells and make them selected + pos = roi.pos() + size = roi.size() + x0 = int(pos.x()) + y0 = int(pos.y()) + x1 = int(pos.x() + size.x()) + y1 = int(pos.y() + size.y()) + if x0 < 0: + x0 = 0 + if y0 < 0: + y0 = 0 + if x1 > self.Lx: + x1 = self.Lx + if y1 > self.Ly: + y1 = self.Ly + + # find cells in that region + cell_idxs = np.unique(self.cellpix[self.currentZ, y0:y1, x0:x1]) + cell_idxs = np.trim_zeros(cell_idxs) + # deselect cells not in region by deselecting all and then selecting the ones in the region + self.clear_multi_selected_cells() + + for idx in cell_idxs: + self.select_cell_multi(idx) + self.removing_cells_list.append(idx) + + self.update_layer() + + def mouse_moved(self, pos): + items = self.win.scene().items(pos) + + def color_choose(self): + self.color = self.RGBDropDown.currentIndex() + self.view = 0 + self.ViewDropDown.setCurrentIndex(self.view) + self.update_plot() + + def update_plot(self): + self.view = self.ViewDropDown.currentIndex() + self.Ly, self.Lx, _ = self.stack[self.currentZ].shape + + if self.view == 0 or self.view == self.ViewDropDown.count() - 1: + image = self.stack[ + self.currentZ] if self.view == 0 else self.stack_filtered[self.currentZ] + if self.color == 0: + self.img.setImage(image, autoLevels=False, lut=None) + if self.nchan > 1: + levels = np.array([ + self.saturation[0][self.currentZ], + self.saturation[1][self.currentZ], + self.saturation[2][self.currentZ] + ]) + self.img.setLevels(levels) + else: + self.img.setLevels(self.saturation[0][self.currentZ]) + elif self.color > 0 and self.color < 4: + if self.nchan > 1: + image = image[:, :, self.color - 1] + self.img.setImage(image, autoLevels=False, lut=self.cmap[self.color]) + if self.nchan > 1: + 
self.img.setLevels(self.saturation[self.color - 1][self.currentZ]) + else: + self.img.setLevels(self.saturation[0][self.currentZ]) + elif self.color == 4: + if self.nchan > 1: + image = image.mean(axis=-1) + self.img.setImage(image, autoLevels=False, lut=None) + self.img.setLevels(self.saturation[0][self.currentZ]) + elif self.color == 5: + if self.nchan > 1: + image = image.mean(axis=-1) + self.img.setImage(image, autoLevels=False, lut=self.cmap[0]) + self.img.setLevels(self.saturation[0][self.currentZ]) + else: + image = np.zeros((self.Ly, self.Lx), np.uint8) + if len(self.flows) >= self.view - 1 and len(self.flows[self.view - 1]) > 0: + image = self.flows[self.view - 1][self.currentZ] + if self.view > 1: + self.img.setImage(image, autoLevels=False, lut=self.bwr) + else: + self.img.setImage(image, autoLevels=False, lut=None) + self.img.setLevels([0.0, 255.0]) + + for r in range(3): + self.sliders[r].setValue([ + self.saturation[r][self.currentZ][0], + self.saturation[r][self.currentZ][1] + ]) + self.win.show() + self.show() + + + def update_layer(self): + if self.masksOn or self.outlinesOn: + self.layer.setImage(self.layerz, autoLevels=False) + self.win.show() + self.show() + + + def add_set(self): + if len(self.current_point_set) > 0: + while len(self.strokes) > 0: + self.remove_stroke(delete_points=False) + if len(self.current_point_set[0]) > 8: + color = self.colormap[self.ncells.get(), :3] + median = self.add_mask(points=self.current_point_set, color=color) + if median is not None: + self.removed_cell = [] + self.toggle_mask_ops() + self.cellcolors = np.append(self.cellcolors, color[np.newaxis, :], + axis=0) + self.ncells += 1 + self.ismanual = np.append(self.ismanual, True) + if self.NZ == 1: + # only save after each cell if single image + io._save_sets_with_check(self) + else: + print("GUI_ERROR: cell too small, not drawn") + self.current_stroke = [] + self.strokes = [] + self.current_point_set = [] + self.update_layer() + + def add_mask(self, points=None, 
color=(100, 200, 50), dense=True): + # points is list of strokes + points_all = np.concatenate(points, axis=0) + + # loop over z values + median = [] + zdraw = np.unique(points_all[:, 0]) + z = 0 + ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros( + 0, "int"), np.zeros(0, "int") + for stroke in points: + stroke = np.concatenate(stroke, axis=0).reshape(-1, 4) + vr = stroke[:, 1] + vc = stroke[:, 2] + # get points inside drawn points + mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8) + pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2), + axis=-1)[:, np.newaxis, :] + mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) + ar, ac = np.nonzero(mask) + ar, ac = ar + vr.min() - 2, ac + vc.min() - 2 + # get dense outline + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0][:,0].T + vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2 + # concatenate all points + ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac)))) + # if these pixels are overlapping with another cell, reassign them + ioverlap = self.cellpix[z][ar, ac] > 0 + if (~ioverlap).sum() < 10: + print("GUI_ERROR: cell < 10 pixels without overlaps, not drawn") + return None + elif ioverlap.sum() > 0: + ar, ac = ar[~ioverlap], ac[~ioverlap] + # compute outline of new mask + mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), np.uint8) + mask[ar - vr.min() + 2, ac - vc.min() + 2] = 1 + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0][:,0].T + vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2 + ars = np.concatenate((ars, ar), axis=0) + acs = np.concatenate((acs, ac), axis=0) + vrs = np.concatenate((vrs, vr), axis=0) + vcs = np.concatenate((vcs, vc), axis=0) + + self.draw_mask(z, ars, acs, vrs, vcs, color) + median.append(np.array([np.median(ars), np.median(acs)])) + + self.zdraw.append(zdraw) + d = datetime.datetime.now() + self.track_changes.append( + [d.strftime("%m/%d/%Y, 
%H:%M:%S"), "added mask", [ar, ac]]) + return median + + def draw_mask(self, z, ar, ac, vr, vc, color, idx=None): + """ draw single mask using outlines and area """ + if idx is None: + idx = self.ncells + 1 + self.cellpix[z, vr, vc] = idx + self.cellpix[z, ar, ac] = idx + self.outpix[z, vr, vc] = idx + if self.restore and "upsample" in self.restore: + if self.resize: + self.cellpix_resize[z, vr, vc] = idx + self.cellpix_resize[z, ar, ac] = idx + self.outpix_resize[z, vr, vc] = idx + self.cellpix_orig[z, (vr / self.ratio).astype(int), + (vc / self.ratio).astype(int)] = idx + self.cellpix_orig[z, (ar / self.ratio).astype(int), + (ac / self.ratio).astype(int)] = idx + self.outpix_orig[z, (vr / self.ratio).astype(int), + (vc / self.ratio).astype(int)] = idx + else: + self.cellpix_orig[z, vr, vc] = idx + self.cellpix_orig[z, ar, ac] = idx + self.outpix_orig[z, vr, vc] = idx + + # get upsampled mask + vrr = (vr.copy() * self.ratio).astype(int) + vcr = (vc.copy() * self.ratio).astype(int) + mask = np.zeros((np.ptp(vrr) + 4, np.ptp(vcr) + 4), np.uint8) + pts = np.stack((vcr - vcr.min() + 2, vrr - vrr.min() + 2), + axis=-1)[:, np.newaxis, :] + mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) + arr, acr = np.nonzero(mask) + arr, acr = arr + vrr.min() - 2, acr + vcr.min() - 2 + # get dense outline + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0].squeeze().T + vrr, vcr = pvr + vrr.min() - 2, pvc + vcr.min() - 2 + # concatenate all points + arr, acr = np.hstack((np.vstack((vrr, vcr)), np.vstack((arr, acr)))) + self.cellpix_resize[z, vrr, vcr] = idx + self.cellpix_resize[z, arr, acr] = idx + self.outpix_resize[z, vrr, vcr] = idx + + if z == self.currentZ: + self.layerz[ar, ac, :3] = color + if self.masksOn: + self.layerz[ar, ac, -1] = self.opacity + if self.outlinesOn: + self.layerz[vr, vc] = np.array(self.outcolor) + + def compute_scale(self): + # get diameter from gui + diameter = self.segmentation_settings.diameter + if 
not diameter: + diameter = 30 + + self.pr = int(diameter) + self.radii_padding = int(self.pr * 1.25) + self.radii = np.zeros((self.Ly + self.radii_padding, self.Lx, 4), np.uint8) + yy, xx = disk([self.Ly + self.radii_padding / 2 - 1, self.pr / 2 + 1], + self.pr / 2, self.Ly + self.radii_padding, self.Lx) + # rgb(150,50,150) + self.radii[yy, xx, 0] = 150 + self.radii[yy, xx, 1] = 50 + self.radii[yy, xx, 2] = 150 + self.radii[yy, xx, 3] = 255 + self.p0.setYRange(0, self.Ly + self.radii_padding) + self.p0.setXRange(0, self.Lx) + + def update_scale(self): + self.compute_scale() + self.scale.setImage(self.radii, autoLevels=False) + self.scale.setLevels([0.0, 255.0]) + self.win.show() + self.show() + + + def draw_layer(self): + if self.resize: + self.Ly, self.Lx = self.Lyr, self.Lxr + else: + self.Ly, self.Lx = self.Ly0, self.Lx0 + + if self.masksOn or self.outlinesOn: + if self.restore and "upsample" in self.restore: + if self.resize: + self.cellpix = self.cellpix_resize.copy() + self.outpix = self.outpix_resize.copy() + else: + self.cellpix = self.cellpix_orig.copy() + self.outpix = self.outpix_orig.copy() + + self.layerz = np.zeros((self.Ly, self.Lx, 4), np.uint8) + if self.masksOn: + self.layerz[..., :3] = self.cellcolors[self.cellpix[self.currentZ], :] + self.layerz[..., 3] = self.opacity * (self.cellpix[self.currentZ] + > 0).astype(np.uint8) + if self.selected > 0: + self.layerz[self.cellpix[self.currentZ] == self.selected] = np.array( + [255, 255, 255, self.opacity]) + cZ = self.currentZ + stroke_z = np.array([s[0][0] for s in self.strokes]) + inZ = np.nonzero(stroke_z == cZ)[0] + if len(inZ) > 0: + for i in inZ: + stroke = np.array(self.strokes[i]) + self.layerz[stroke[:, 1], stroke[:, + 2]] = np.array([255, 0, 255, 100]) + else: + self.layerz[..., 3] = 0 + + if self.outlinesOn: + self.layerz[self.outpix[self.currentZ] > 0] = np.array( + self.outcolor).astype(np.uint8) + + + def set_normalize_params(self, normalize_params): + from cellpose.models import 
normalize_default + if self.restore != "filter": + keys = list(normalize_params.keys()).copy() + for key in keys: + if key != "percentile": + normalize_params[key] = normalize_default[key] + normalize_params = {**normalize_default, **normalize_params} + out = self.check_filter_params(normalize_params["sharpen_radius"], + normalize_params["smooth_radius"], + normalize_params["tile_norm_blocksize"], + normalize_params["tile_norm_smooth3D"], + normalize_params["norm3D"], + normalize_params["invert"]) + + + def check_filter_params(self, sharpen, smooth, tile_norm, smooth3D, norm3D, invert): + tile_norm = 0 if tile_norm < 0 else tile_norm + sharpen = 0 if sharpen < 0 else sharpen + smooth = 0 if smooth < 0 else smooth + smooth3D = 0 if smooth3D < 0 else smooth3D + norm3D = bool(norm3D) + invert = bool(invert) + if tile_norm > self.Ly and tile_norm > self.Lx: + print( + "GUI_ERROR: tile size (tile_norm) bigger than both image dimensions, disabling" + ) + tile_norm = 0 + self.filt_edits[0].setText(str(sharpen)) + self.filt_edits[1].setText(str(smooth)) + self.filt_edits[2].setText(str(tile_norm)) + self.filt_edits[3].setText(str(smooth3D)) + self.norm3D_cb.setChecked(norm3D) + return sharpen, smooth, tile_norm, smooth3D, norm3D, invert + + def get_normalize_params(self): + percentile = [ + self.segmentation_settings.low_percentile, + self.segmentation_settings.high_percentile, + ] + normalize_params = {"percentile": percentile} + norm3D = self.norm3D_cb.isChecked() + normalize_params["norm3D"] = norm3D + sharpen = float(self.filt_edits[0].text()) + smooth = float(self.filt_edits[1].text()) + tile_norm = float(self.filt_edits[2].text()) + smooth3D = float(self.filt_edits[3].text()) + invert = False + out = self.check_filter_params(sharpen, smooth, tile_norm, smooth3D, norm3D, + invert) + sharpen, smooth, tile_norm, smooth3D, norm3D, invert = out + normalize_params["sharpen_radius"] = sharpen + normalize_params["smooth_radius"] = smooth + 
normalize_params["tile_norm_blocksize"] = tile_norm + normalize_params["tile_norm_smooth3D"] = smooth3D + normalize_params["invert"] = invert + + from cellpose.models import normalize_default + normalize_params = {**normalize_default, **normalize_params} + + return normalize_params + + def compute_saturation_if_checked(self): + if self.autobtn.isChecked(): + self.compute_saturation() + + def compute_saturation(self, return_img=False): + norm = self.get_normalize_params() + print(norm) + sharpen, smooth = norm["sharpen_radius"], norm["smooth_radius"] + percentile = norm["percentile"] + tile_norm = norm["tile_norm_blocksize"] + invert = norm["invert"] + norm3D = norm["norm3D"] + smooth3D = norm["tile_norm_smooth3D"] + tile_norm = norm["tile_norm_blocksize"] + + if sharpen > 0 or smooth > 0 or tile_norm > 0: + img_norm = self.stack.copy() + else: + img_norm = self.stack + + if sharpen > 0 or smooth > 0 or tile_norm > 0: + self.restore = "filter" + print( + "GUI_INFO: computing filtered image because sharpen > 0 or tile_norm > 0" + ) + print( + "GUI_WARNING: will use memory to create filtered image -- make sure to have RAM for this" + ) + img_norm = self.stack.copy() + if sharpen > 0 or smooth > 0: + img_norm = smooth_sharpen_img(self.stack, sharpen_radius=sharpen, + smooth_radius=smooth) + + if tile_norm > 0: + img_norm = normalize99_tile(img_norm, blocksize=tile_norm, + lower=percentile[0], upper=percentile[1], + smooth3D=smooth3D, norm3D=norm3D) + # convert to 0->255 + img_norm_min = img_norm.min() + img_norm_max = img_norm.max() + for c in range(img_norm.shape[-1]): + if np.ptp(img_norm[..., c]) > 1e-3: + img_norm[..., c] -= img_norm_min + img_norm[..., c] /= (img_norm_max - img_norm_min) + img_norm *= 255 + self.stack_filtered = img_norm + self.ViewDropDown.model().item(self.ViewDropDown.count() - + 1).setEnabled(True) + self.ViewDropDown.setCurrentIndex(self.ViewDropDown.count() - 1) + else: + img_norm = self.stack if self.restore is None or self.restore == 
"filter" else self.stack_filtered + + if self.autobtn.isChecked(): + self.saturation = [] + for c in range(img_norm.shape[-1]): + self.saturation.append([]) + if np.ptp(img_norm[..., c]) > 1e-3: + if norm3D: + x01 = np.percentile(img_norm[..., c], percentile[0]) + x99 = np.percentile(img_norm[..., c], percentile[1]) + if invert: + x01i = 255. - x99 + x99i = 255. - x01 + x01, x99 = x01i, x99i + for n in range(self.NZ): + self.saturation[-1].append([x01, x99]) + else: + for z in range(self.NZ): + if self.NZ > 1: + x01 = np.percentile(img_norm[z, :, :, c], percentile[0]) + x99 = np.percentile(img_norm[z, :, :, c], percentile[1]) + else: + x01 = np.percentile(img_norm[..., c], percentile[0]) + x99 = np.percentile(img_norm[..., c], percentile[1]) + if invert: + x01i = 255. - x99 + x99i = 255. - x01 + x01, x99 = x01i, x99i + self.saturation[-1].append([x01, x99]) + else: + for n in range(self.NZ): + self.saturation[-1].append([0, 255.]) + print(self.saturation[2][self.currentZ]) + + if img_norm.shape[-1] == 1: + self.saturation.append(self.saturation[0]) + self.saturation.append(self.saturation[0]) + + # self.autobtn.setChecked(True) + self.update_plot() + + + def get_model_path(self, custom=False): + if custom: + self.current_model = self.ModelChooseC.currentText() + self.current_model_path = os.fspath( + models.MODEL_DIR.joinpath(self.current_model)) + else: + self.current_model = "cpsam" + self.current_model_path = models.model_path(self.current_model) + + def initialize_model(self, model_name=None, custom=False): + if model_name is None or custom: + self.get_model_path(custom=custom) + if not os.path.exists(self.current_model_path): + raise ValueError("need to specify model (use dropdown)") + + if model_name is None or not isinstance(model_name, str): + self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), + pretrained_model=self.current_model_path) + else: + self.current_model = model_name + self.current_model_path = os.fspath( + 
models.MODEL_DIR.joinpath(self.current_model)) + + self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), + pretrained_model=self.current_model) + + def add_model(self): + io._add_model(self) + return + + def remove_model(self): + io._remove_model(self) + return + + def new_model(self): + if self.NZ != 1: + print("ERROR: cannot train model on 3D data") + return + + # train model + image_names = self.get_files()[0] + self.train_data, self.train_labels, self.train_files, restore, normalize_params = io._get_train_set( + image_names) + TW = guiparts.TrainWindow(self, models.MODEL_NAMES) + train = TW.exec_() + if train: + self.logger.info( + f"training with {[os.path.split(f)[1] for f in self.train_files]}") + self.train_model(restore=restore, normalize_params=normalize_params) + else: + print("GUI_INFO: training cancelled") + + def train_model(self, restore=None, normalize_params=None): + from cellpose.models import normalize_default + if normalize_params is None: + normalize_params = copy.deepcopy(normalize_default) + model_type = models.MODEL_NAMES[self.training_params["model_index"]] + self.logger.info(f"training new model starting at model {model_type}") + self.current_model = model_type + + self.model = models.CellposeModel(gpu=self.useGPU.isChecked(), + model_type=model_type) + save_path = os.path.dirname(self.filename) + + print("GUI_INFO: name of new model: " + self.training_params["model_name"]) + self.new_model_path, train_losses = train.train_seg( + self.model.net, train_data=self.train_data, train_labels=self.train_labels, + normalize=normalize_params, min_train_masks=0, + save_path=save_path, nimg_per_epoch=max(2, len(self.train_data)), + learning_rate=self.training_params["learning_rate"], + weight_decay=self.training_params["weight_decay"], + n_epochs=self.training_params["n_epochs"], + model_name=self.training_params["model_name"])[:2] + # save train losses + np.save(str(self.new_model_path) + "_train_losses.npy", train_losses) + # run model on 
next image + io._add_model(self, self.new_model_path) + diam_labels = self.model.net.diam_labels.item() #.copy() + self.new_model_ind = len(self.model_strings) + self.autorun = True + self.clear_all() + self.restore = restore + self.set_normalize_params(normalize_params) + self.get_next_image(load_seg=False) + + self.compute_segmentation(custom=True) + self.logger.info( + f"!!! computed masks for {os.path.split(self.filename)[1]} from new model !!!" + ) + + + def compute_cprob(self): + if self.recompute_masks: + flow_threshold = self.segmentation_settings.flow_threshold + cellprob_threshold = self.segmentation_settings.cellprob_threshold + niter = self.segmentation_settings.niter + min_size = int(self.min_size.text()) if not isinstance( + self.min_size, int) else self.min_size + + self.logger.info( + "computing masks with cell prob=%0.3f, flow error threshold=%0.3f" % + (cellprob_threshold, flow_threshold)) + + try: + dP = self.flows[2].squeeze() + cellprob = self.flows[3].squeeze() + except IndexError: + self.logger.error("Flows don't exist, try running model again.") + return + + maski = dynamics.resize_and_compute_masks( + dP=dP, + cellprob=cellprob, + niter=niter, + do_3D=self.load_3D, + min_size=min_size, + # max_size_fraction=min_size_fraction, # Leave as default + cellprob_threshold=cellprob_threshold, + flow_threshold=flow_threshold) + + self.masksOn = True + if not self.OCheckBox.isChecked(): + self.MCheckBox.setChecked(True) + if maski.ndim < 3: + maski = maski[np.newaxis, ...] 
+ self.logger.info("%d cells found" % (len(np.unique(maski)[1:]))) + io._masks_to_gui(self, maski, outlines=None) + self.show() + + + def compute_segmentation(self, custom=False, model_name=None, load_model=True): + self.progress.setValue(0) + try: + tic = time.time() + self.clear_all() + self.flows = [[], [], []] + if load_model: + self.initialize_model(model_name=model_name, custom=custom) + self.progress.setValue(10) + do_3D = self.load_3D + stitch_threshold = float(self.stitch_threshold.text()) if not isinstance( + self.stitch_threshold, float) else self.stitch_threshold + anisotropy = float(self.anisotropy.text()) if not isinstance( + self.anisotropy, float) else self.anisotropy + flow3D_smooth = float(self.flow3D_smooth.text()) if not isinstance( + self.flow3D_smooth, float) else self.flow3D_smooth + min_size = int(self.min_size.text()) if not isinstance( + self.min_size, int) else self.min_size + + do_3D = False if stitch_threshold > 0. else do_3D + + if self.restore == "filter": + data = self.stack_filtered.copy().squeeze() + else: + data = self.stack.copy().squeeze() + + flow_threshold = self.segmentation_settings.flow_threshold + cellprob_threshold = self.segmentation_settings.cellprob_threshold + diameter = self.segmentation_settings.diameter + niter = self.segmentation_settings.niter + + normalize_params = self.get_normalize_params() + print(normalize_params) + try: + masks, flows = self.model.eval( + data, + diameter=diameter, + cellprob_threshold=cellprob_threshold, + flow_threshold=flow_threshold, do_3D=do_3D, niter=niter, + normalize=normalize_params, stitch_threshold=stitch_threshold, + anisotropy=anisotropy, flow3D_smooth=flow3D_smooth, + min_size=min_size, channel_axis=-1, + progress=self.progress, z_axis=0 if self.NZ > 1 else None)[:2] + except Exception as e: + print("NET ERROR: %s" % e) + self.progress.setValue(0) + return + + self.progress.setValue(75) + + # convert flows to uint8 and resize to original image size + flows_new = [] + 
flows_new.append(flows[0].copy()) # RGB flow + flows_new.append((np.clip(normalize99(flows[2].copy()), 0, 1) * + 255).astype("uint8")) # cellprob + flows_new.append(flows[1].copy()) # XY flows + flows_new.append(flows[2].copy()) # original cellprob + + if self.load_3D: + if stitch_threshold == 0.: + flows_new.append((flows[1][0] / 10 * 127 + 127).astype("uint8")) + else: + flows_new.append(np.zeros(flows[1][0].shape, dtype="uint8")) + + if not self.load_3D: + if self.restore and "upsample" in self.restore: + self.Ly, self.Lx = self.Lyr, self.Lxr + + if flows_new[0].shape[-3:-1] != (self.Ly, self.Lx): + self.flows = [] + for j in range(len(flows_new)): + self.flows.append( + resize_image(flows_new[j], Ly=self.Ly, Lx=self.Lx, + interpolation=cv2.INTER_NEAREST)) + else: + self.flows = flows_new + else: + self.flows = [] + Lz, Ly, Lx = self.NZ, self.Ly, self.Lx + Lz0, Ly0, Lx0 = flows_new[0].shape[:3] + print("GUI_INFO: resizing flows to original image size") + for j in range(len(flows_new)): + flow0 = flows_new[j] + if Ly0 != Ly: + flow0 = resize_image(flow0, Ly=Ly, Lx=Lx, + no_channels=flow0.ndim==3, + interpolation=cv2.INTER_NEAREST) + if Lz0 != Lz: + flow0 = np.swapaxes(resize_image(np.swapaxes(flow0, 0, 1), + Ly=Lz, Lx=Lx, + no_channels=flow0.ndim==3, + interpolation=cv2.INTER_NEAREST), 0, 1) + self.flows.append(flow0) + + # add first axis + if self.NZ == 1: + masks = masks[np.newaxis, ...] + self.flows = [ + self.flows[n][np.newaxis, ...] 
for n in range(len(self.flows)) + ] + + self.logger.info("%d cells found with model in %0.3f sec" % + (len(np.unique(masks)[1:]), time.time() - tic)) + self.progress.setValue(80) + z = 0 + + io._masks_to_gui(self, masks, outlines=None) + self.masksOn = True + self.MCheckBox.setChecked(True) + self.progress.setValue(100) + if self.restore != "filter" and self.restore is not None and self.autobtn.isChecked(): + self.compute_saturation() + if not do_3D and not stitch_threshold > 0: + self.recompute_masks = True + else: + self.recompute_masks = False + except Exception as e: + print("ERROR: %s" % e) diff --git a/models/seg_post_model/cellpose/gui/gui3d.py b/models/seg_post_model/cellpose/gui/gui3d.py new file mode 100644 index 0000000000000000000000000000000000000000..c72665750247fedd20293a3d6080b47cb6e8be06 --- /dev/null +++ b/models/seg_post_model/cellpose/gui/gui3d.py @@ -0,0 +1,667 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu. +""" + +import sys, pathlib, warnings + +from qtpy import QtGui, QtCore +from qtpy.QtWidgets import QApplication, QScrollBar, QCheckBox, QLabel, QLineEdit +import pyqtgraph as pg + +import numpy as np +from scipy.stats import mode +import cv2 + +from . import guiparts, io +from ..utils import download_url_to_file, masks_to_outlines +from .gui import MainW + +try: + import matplotlib.pyplot as plt + MATPLOTLIB = True +except: + MATPLOTLIB = False + + +def avg3d(C): + """ smooth value of c across nearby points + (c is center of grid directly below point) + b -- a -- b + a -- c -- a + b -- a -- b + """ + Ly, Lx = C.shape + # pad T by 2 + T = np.zeros((Ly + 2, Lx + 2), "float32") + M = np.zeros((Ly, Lx), "float32") + T[1:-1, 1:-1] = C.copy() + y, x = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int), + indexing="ij") + y += 1 + x += 1 + a = 1. / 2 #/(z**2 + 1)**0.5 + b = 1. / (1 + 2**0.5) #(z**2 + 2)**0.5 + c = 1. 
+ M = (b * T[y - 1, x - 1] + a * T[y - 1, x] + b * T[y - 1, x + 1] + a * T[y, x - 1] + + c * T[y, x] + a * T[y, x + 1] + b * T[y + 1, x - 1] + a * T[y + 1, x] + + b * T[y + 1, x + 1]) + M /= 4 * a + 4 * b + c + return M + + +def interpZ(mask, zdraw): + """ find nearby planes and average their values using grid of points + zfill is in ascending order + """ + ifill = np.ones(mask.shape[0], "bool") + zall = np.arange(0, mask.shape[0], 1, int) + ifill[zdraw] = False + zfill = zall[ifill] + zlower = zdraw[np.searchsorted(zdraw, zfill, side="left") - 1] + zupper = zdraw[np.searchsorted(zdraw, zfill, side="right")] + for k, z in enumerate(zfill): + Z = zupper[k] - zlower[k] + zl = (z - zlower[k]) / Z + plower = avg3d(mask[zlower[k]]) * (1 - zl) + pupper = avg3d(mask[zupper[k]]) * zl + mask[z] = (plower + pupper) > 0.33 + return mask, zfill + + +def run(image=None): + from ..io import logger_setup + logger, log_file = logger_setup() + # Always start by initializing Qt (only once per application) + warnings.filterwarnings("ignore") + app = QApplication(sys.argv) + icon_path = pathlib.Path.home().joinpath(".cellpose", "logo.png") + guip_path = pathlib.Path.home().joinpath(".cellpose", "cellpose_gui.png") + style_path = pathlib.Path.home().joinpath(".cellpose", "style_choice.npy") + if not icon_path.is_file(): + cp_dir = pathlib.Path.home().joinpath(".cellpose") + cp_dir.mkdir(exist_ok=True) + print("downloading logo") + download_url_to_file( + "https://www.cellpose.org/static/images/cellpose_transparent.png", + icon_path, progress=True) + if not guip_path.is_file(): + print("downloading help window image") + download_url_to_file("https://www.cellpose.org/static/images/cellpose_gui.png", + guip_path, progress=True) + icon_path = str(icon_path.resolve()) + app_icon = QtGui.QIcon() + app_icon.addFile(icon_path, QtCore.QSize(16, 16)) + app_icon.addFile(icon_path, QtCore.QSize(24, 24)) + app_icon.addFile(icon_path, QtCore.QSize(32, 32)) + app_icon.addFile(icon_path, 
QtCore.QSize(48, 48)) + app_icon.addFile(icon_path, QtCore.QSize(64, 64)) + app_icon.addFile(icon_path, QtCore.QSize(256, 256)) + app.setWindowIcon(app_icon) + app.setStyle("Fusion") + app.setPalette(guiparts.DarkPalette()) + MainW_3d(image=image, logger=logger) + ret = app.exec_() + sys.exit(ret) + + +class MainW_3d(MainW): + + def __init__(self, image=None, logger=None): + # MainW init + MainW.__init__(self, image=image, logger=logger) + + # add gradZ view + self.ViewDropDown.insertItem(3, "gradZ") + + # turn off single stroke + self.SCheckBox.setChecked(False) + + ### add orthoviews and z-bar + # ortho crosshair lines + self.vLine = pg.InfiniteLine(angle=90, movable=False) + self.hLine = pg.InfiniteLine(angle=0, movable=False) + self.vLineOrtho = [ + pg.InfiniteLine(angle=90, movable=False), + pg.InfiniteLine(angle=90, movable=False) + ] + self.hLineOrtho = [ + pg.InfiniteLine(angle=0, movable=False), + pg.InfiniteLine(angle=0, movable=False) + ] + self.make_orthoviews() + + # z scrollbar underneath + self.scroll = QScrollBar(QtCore.Qt.Horizontal) + self.scroll.setMaximum(10) + self.scroll.valueChanged.connect(self.move_in_Z) + self.lmain.addWidget(self.scroll, 40, 9, 1, 30) + + b = 22 + + label = QLabel("stitch\nthreshold:") + label.setToolTip( + "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)" + ) + label.setFont(self.medfont) + self.segBoxG.addWidget(label, b, 0, 1, 4) + self.stitch_threshold = QLineEdit() + self.stitch_threshold.setText("0.0") + self.stitch_threshold.setFixedWidth(30) + self.stitch_threshold.setFont(self.medfont) + self.stitch_threshold.setToolTip( + "for 3D volumes, turn on stitch_threshold to stitch masks across planes instead of running cellpose in 3D (see docs for details)" + ) + self.segBoxG.addWidget(self.stitch_threshold, b, 3, 1, 1) + + label = QLabel("flow3D\nsmooth:") + label.setToolTip( + "for 3D volumes, smooth flows by a Gaussian with standard 
deviation flow3D_smooth (see docs for details)" + ) + label.setFont(self.medfont) + self.segBoxG.addWidget(label, b, 4, 1, 3) + self.flow3D_smooth = QLineEdit() + self.flow3D_smooth.setText("0.0") + self.flow3D_smooth.setFixedWidth(30) + self.flow3D_smooth.setFont(self.medfont) + self.flow3D_smooth.setToolTip( + "for 3D volumes, smooth flows by a Gaussian with standard deviation flow3D_smooth (see docs for details)" + ) + self.segBoxG.addWidget(self.flow3D_smooth, b, 7, 1, 1) + + b+=1 + label = QLabel("anisotropy:") + label.setToolTip( + "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)" + ) + label.setFont(self.medfont) + self.segBoxG.addWidget(label, b, 0, 1, 3) + self.anisotropy = QLineEdit() + self.anisotropy.setText("1.0") + self.anisotropy.setFixedWidth(30) + self.anisotropy.setFont(self.medfont) + self.anisotropy.setToolTip( + "for 3D volumes, increase in sampling in Z vs XY as a ratio, e.g. 
set set to 2.0 if Z is sampled half as dense as X or Y (see docs for details)" + ) + self.segBoxG.addWidget(self.anisotropy, b, 3, 1, 1) + + b+=1 + label = QLabel("min\nsize:") + label.setToolTip( + "all masks less than this size in pixels (volume) will be removed" + ) + label.setFont(self.medfont) + self.segBoxG.addWidget(label, b, 0, 1, 4) + self.min_size = QLineEdit() + self.min_size.setText("15") + self.min_size.setFixedWidth(50) + self.min_size.setFont(self.medfont) + self.min_size.setToolTip( + "all masks less than this size in pixels (volume) will be removed" + ) + self.segBoxG.addWidget(self.min_size, b, 3, 1, 1) + + b += 1 + self.orthobtn = QCheckBox("ortho") + self.orthobtn.setToolTip("activate orthoviews with 3D image") + self.orthobtn.setFont(self.medfont) + self.orthobtn.setChecked(False) + self.l0.addWidget(self.orthobtn, b, 0, 1, 2) + self.orthobtn.toggled.connect(self.toggle_ortho) + + label = QLabel("dz:") + label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + label.setFont(self.medfont) + self.l0.addWidget(label, b, 2, 1, 1) + self.dz = 10 + self.dzedit = QLineEdit() + self.dzedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.dzedit.setText(str(self.dz)) + self.dzedit.returnPressed.connect(self.update_ortho) + self.dzedit.setFixedWidth(40) + self.dzedit.setFont(self.medfont) + self.l0.addWidget(self.dzedit, b, 3, 1, 2) + + label = QLabel("z-aspect:") + label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + label.setFont(self.medfont) + self.l0.addWidget(label, b, 5, 1, 2) + self.zaspect = 1.0 + self.zaspectedit = QLineEdit() + self.zaspectedit.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.zaspectedit.setText(str(self.zaspect)) + self.zaspectedit.returnPressed.connect(self.update_ortho) + self.zaspectedit.setFixedWidth(40) + self.zaspectedit.setFont(self.medfont) + self.l0.addWidget(self.zaspectedit, b, 7, 1, 2) + + b += 1 + # add z position underneath + self.currentZ = 0 + label 
= QLabel("Z:") + label.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.l0.addWidget(label, b, 5, 1, 2) + self.zpos = QLineEdit() + self.zpos.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter) + self.zpos.setText(str(self.currentZ)) + self.zpos.returnPressed.connect(self.update_ztext) + self.zpos.setFixedWidth(40) + self.zpos.setFont(self.medfont) + self.l0.addWidget(self.zpos, b, 7, 1, 2) + + # if called with image, load it + if image is not None: + self.filename = image + io._load_image(self, self.filename, load_3D=True) + + self.load_3D = True + + def add_mask(self, points=None, color=(100, 200, 50), dense=True): + # points is list of strokes + + points_all = np.concatenate(points, axis=0) + + # loop over z values + median = [] + zdraw = np.unique(points_all[:, 0]) + zrange = np.arange(zdraw.min(), zdraw.max() + 1, 1, int) + zmin = zdraw.min() + pix = np.zeros((2, 0), "uint16") + mall = np.zeros((len(zrange), self.Ly, self.Lx), "bool") + k = 0 + for z in zdraw: + ars, acs, vrs, vcs = np.zeros(0, "int"), np.zeros(0, "int"), np.zeros( + 0, "int"), np.zeros(0, "int") + for stroke in points: + stroke = np.concatenate(stroke, axis=0).reshape(-1, 4) + iz = stroke[:, 0] == z + vr = stroke[iz, 1] + vc = stroke[iz, 2] + if iz.sum() > 0: + # get points inside drawn points + mask = np.zeros((np.ptp(vr) + 4, np.ptp(vc) + 4), "uint8") + pts = np.stack((vc - vc.min() + 2, vr - vr.min() + 2), + axis=-1)[:, np.newaxis, :] + mask = cv2.fillPoly(mask, [pts], (255, 0, 0)) + ar, ac = np.nonzero(mask) + ar, ac = ar + vr.min() - 2, ac + vc.min() - 2 + # get dense outline + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0].squeeze().T + vr, vc = pvr + vr.min() - 2, pvc + vc.min() - 2 + # concatenate all points + ar, ac = np.hstack((np.vstack((vr, vc)), np.vstack((ar, ac)))) + # if these pixels are overlapping with another cell, reassign them + ioverlap = self.cellpix[z][ar, ac] > 0 + if 
(~ioverlap).sum() < 8: + print("ERROR: cell too small without overlaps, not drawn") + return None + elif ioverlap.sum() > 0: + ar, ac = ar[~ioverlap], ac[~ioverlap] + # compute outline of new mask + mask = np.zeros((np.ptp(ar) + 4, np.ptp(ac) + 4), "uint8") + mask[ar - ar.min() + 2, ac - ac.min() + 2] = 1 + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = contours[-2][0].squeeze().T + vr, vc = pvr + ar.min() - 2, pvc + ac.min() - 2 + ars = np.concatenate((ars, ar), axis=0) + acs = np.concatenate((acs, ac), axis=0) + vrs = np.concatenate((vrs, vr), axis=0) + vcs = np.concatenate((vcs, vc), axis=0) + self.draw_mask(z, ars, acs, vrs, vcs, color) + + median.append(np.array([np.median(ars), np.median(acs)])) + mall[z - zmin, ars, acs] = True + pix = np.append(pix, np.vstack((ars, acs)), axis=-1) + + mall = mall[:, pix[0].min():pix[0].max() + 1, + pix[1].min():pix[1].max() + 1].astype("float32") + ymin, xmin = pix[0].min(), pix[1].min() + if len(zdraw) > 1: + mall, zfill = interpZ(mall, zdraw - zmin) + for z in zfill: + mask = mall[z].copy() + ar, ac = np.nonzero(mask) + ioverlap = self.cellpix[z + zmin][ar + ymin, ac + xmin] > 0 + if (~ioverlap).sum() < 5: + print("WARNING: stroke on plane %d not included due to overlaps" % + z) + elif ioverlap.sum() > 0: + mask[ar[ioverlap], ac[ioverlap]] = 0 + ar, ac = ar[~ioverlap], ac[~ioverlap] + # compute outline of mask + outlines = masks_to_outlines(mask) + vr, vc = np.nonzero(outlines) + vr, vc = vr + ymin, vc + xmin + ar, ac = ar + ymin, ac + xmin + self.draw_mask(z + zmin, ar, ac, vr, vc, color) + + self.zdraw.append(zdraw) + + return median + + def move_in_Z(self): + if self.loaded: + self.currentZ = min(self.NZ, max(0, int(self.scroll.value()))) + self.zpos.setText(str(self.currentZ)) + self.update_plot() + self.draw_layer() + self.update_layer() + + def make_orthoviews(self): + self.pOrtho, self.imgOrtho, self.layerOrtho = [], [], [] + for j in range(2): + self.pOrtho.append( + 
pg.ViewBox(lockAspect=True, name=f"plotOrtho{j}", + border=[100, 100, 100], invertY=True, enableMouse=False)) + self.pOrtho[j].setMenuEnabled(False) + + self.imgOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self)) + self.imgOrtho[j].autoDownsample = False + + self.layerOrtho.append(pg.ImageItem(viewbox=self.pOrtho[j], parent=self)) + self.layerOrtho[j].setLevels([0., 255.]) + + #self.pOrtho[j].scene().contextMenuItem = self.pOrtho[j] + self.pOrtho[j].addItem(self.imgOrtho[j]) + self.pOrtho[j].addItem(self.layerOrtho[j]) + self.pOrtho[j].addItem(self.vLineOrtho[j], ignoreBounds=False) + self.pOrtho[j].addItem(self.hLineOrtho[j], ignoreBounds=False) + + self.pOrtho[0].linkView(self.pOrtho[0].YAxis, self.p0) + self.pOrtho[1].linkView(self.pOrtho[1].XAxis, self.p0) + + def add_orthoviews(self): + self.yortho = self.Ly // 2 + self.xortho = self.Lx // 2 + if self.NZ > 1: + self.update_ortho() + + self.win.addItem(self.pOrtho[0], 0, 1, rowspan=1, colspan=1) + self.win.addItem(self.pOrtho[1], 1, 0, rowspan=1, colspan=1) + + qGraphicsGridLayout = self.win.ci.layout + qGraphicsGridLayout.setColumnStretchFactor(0, 2) + qGraphicsGridLayout.setColumnStretchFactor(1, 1) + qGraphicsGridLayout.setRowStretchFactor(0, 2) + qGraphicsGridLayout.setRowStretchFactor(1, 1) + + self.pOrtho[0].setYRange(0, self.Lx) + self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3) + self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3) + self.pOrtho[1].setXRange(0, self.Ly) + + self.p0.addItem(self.vLine, ignoreBounds=False) + self.p0.addItem(self.hLine, ignoreBounds=False) + self.p0.setYRange(0, self.Lx) + self.p0.setXRange(0, self.Ly) + + self.win.show() + self.show() + + def remove_orthoviews(self): + self.win.removeItem(self.pOrtho[0]) + self.win.removeItem(self.pOrtho[1]) + self.p0.removeItem(self.vLine) + self.p0.removeItem(self.hLine) + self.win.show() + self.show() + + def update_crosshairs(self): + self.yortho = min(self.Ly - 1, max(0, int(self.yortho))) + 
    def update_ortho(self):
        """Refresh the two orthogonal (YZ / ZX) view panels around the crosshair.

        Reads dz and z-aspect from their edit boxes (clamped to [3, 100] and
        [0.01, 100] respectively), selects a z-window of up to 2*dz planes
        centered on the current plane, and pushes the appropriate image slice
        plus mask/outline overlays into the ortho image items.
        """
        if self.NZ > 1 and self.orthobtn.isChecked():
            dzcurrent = self.dz
            self.dz = min(100, max(3, int(self.dzedit.text())))
            self.zaspect = max(0.01, min(100., float(self.zaspectedit.text())))
            # write the clamped values back to the boxes
            self.dzedit.setText(str(self.dz))
            self.zaspectedit.setText(str(self.zaspect))
            if self.dz != dzcurrent:
                # resize ortho panels to the new z-window
                self.pOrtho[0].setXRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
                self.pOrtho[1].setYRange(-self.dz / 3, self.dz * 2 + self.dz / 3)
            dztot = min(self.NZ, self.dz * 2)
            y = self.yortho
            x = self.xortho
            z = self.currentZ
            # choose [zmin, zmax) window of dztot planes, clipped to the stack
            if dztot == self.NZ:
                zmin, zmax = 0, self.NZ
            else:
                if z - self.dz < 0:
                    zmin = 0
                    zmax = zmin + self.dz * 2
                elif z + self.dz >= self.NZ:
                    zmax = self.NZ
                    zmin = zmax - self.dz * 2
                else:
                    zmin, zmax = z - self.dz, z + self.dz
            self.zc = z - zmin  # crosshair z position within the window
            self.update_crosshairs()
            if self.view == 0 or self.view == 4:
                for j in range(2):
                    # j == 0: slice at fixed x (YZ panel); j == 1: fixed y (ZX panel)
                    if j == 0:
                        if self.view == 0:
                            image = self.stack[zmin:zmax, :, x].transpose(1, 0, 2).copy()
                        else:
                            image = self.stack_filtered[zmin:zmax, :,
                                                        x].transpose(1, 0, 2).copy()
                    else:
                        image = self.stack[
                            zmin:zmax,
                            y, :].copy() if self.view == 0 else self.stack_filtered[zmin:zmax,
                                                                                    y, :].copy()
                    if self.nchan == 1:
                        # show single channel
                        image = image[..., 0]
                    if self.color == 0:
                        # RGB view
                        self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
                        if self.nchan > 1:
                            levels = np.array([
                                self.saturation[0][self.currentZ],
                                self.saturation[1][self.currentZ],
                                self.saturation[2][self.currentZ]
                            ])
                            self.imgOrtho[j].setLevels(levels)
                        else:
                            self.imgOrtho[j].setLevels(
                                self.saturation[0][self.currentZ])
                    elif self.color > 0 and self.color < 4:
                        # single-channel red/green/blue view with colormap
                        if self.nchan > 1:
                            image = image[..., self.color - 1]
                        self.imgOrtho[j].setImage(image, autoLevels=False,
                                                  lut=self.cmap[self.color])
                        if self.nchan > 1:
                            self.imgOrtho[j].setLevels(
                                self.saturation[self.color - 1][self.currentZ])
                        else:
                            self.imgOrtho[j].setLevels(
                                self.saturation[0][self.currentZ])
                    elif self.color == 4:
                        # grayscale mean over channels
                        if image.ndim > 2:
                            image = image.astype("float32").mean(axis=2).astype("uint8")
                        self.imgOrtho[j].setImage(image, autoLevels=False, lut=None)
                        self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
                    elif self.color == 5:
                        # mean over channels with colormap 0
                        if image.ndim > 2:
                            image = image.astype("float32").mean(axis=2).astype("uint8")
                        self.imgOrtho[j].setImage(image, autoLevels=False,
                                                  lut=self.cmap[0])
                        self.imgOrtho[j].setLevels(self.saturation[0][self.currentZ])
                self.pOrtho[0].setAspectLocked(lock=True, ratio=self.zaspect)
                self.pOrtho[1].setAspectLocked(lock=True, ratio=1. / self.zaspect)

            else:
                # non-image views: blank out the ortho panels
                image = np.zeros((10, 10), "uint8")
                self.imgOrtho[0].setImage(image, autoLevels=False, lut=None)
                self.imgOrtho[0].setLevels([0.0, 255.0])
                self.imgOrtho[1].setImage(image, autoLevels=False, lut=None)
                self.imgOrtho[1].setLevels([0.0, 255.0])

            zrange = zmax - zmin
            # RGBA overlay layers for masks/outlines in each ortho panel
            self.layer_ortho = [
                np.zeros((self.Ly, zrange, 4), "uint8"),
                np.zeros((zrange, self.Lx, 4), "uint8")
            ]
            if self.masksOn:
                for j in range(2):
                    if j == 0:
                        cp = self.cellpix[zmin:zmax, :, x].T
                    else:
                        cp = self.cellpix[zmin:zmax, y]
                    self.layer_ortho[j][..., :3] = self.cellcolors[cp, :]
                    self.layer_ortho[j][..., 3] = self.opacity * (cp > 0).astype("uint8")
                    if self.selected > 0:
                        # highlight the selected mask in white
                        self.layer_ortho[j][cp == self.selected] = np.array(
                            [255, 255, 255, self.opacity])

            if self.outlinesOn:
                for j in range(2):
                    if j == 0:
                        op = self.outpix[zmin:zmax, :, x].T
                    else:
                        op = self.outpix[zmin:zmax, y]
                    self.layer_ortho[j][op > 0] = np.array(self.outcolor).astype("uint8")

            for j in range(2):
                self.layerOrtho[j].setImage(self.layer_ortho[j])
        self.win.show()
        self.show()
toggle_ortho(self): + if self.orthobtn.isChecked(): + self.add_orthoviews() + else: + self.remove_orthoviews() + + def plot_clicked(self, event): + if event.button()==QtCore.Qt.LeftButton \ + and not event.modifiers() & (QtCore.Qt.ShiftModifier | QtCore.Qt.AltModifier)\ + and not self.removing_region: + if event.double(): + try: + self.p0.setYRange(0, self.Ly + self.pr) + except: + self.p0.setYRange(0, self.Ly) + self.p0.setXRange(0, self.Lx) + elif self.loaded and not self.in_stroke: + if self.orthobtn.isChecked(): + items = self.win.scene().items(event.scenePos()) + for x in items: + if x == self.p0: + pos = self.p0.mapSceneToView(event.scenePos()) + x = int(pos.x()) + y = int(pos.y()) + if y >= 0 and y < self.Ly and x >= 0 and x < self.Lx: + self.yortho = y + self.xortho = x + self.update_ortho() + + def update_plot(self): + super().update_plot() + if self.NZ > 1 and self.orthobtn.isChecked(): + self.update_ortho() + self.win.show() + self.show() + + def keyPressEvent(self, event): + if self.loaded: + if not (event.modifiers() & + (QtCore.Qt.ControlModifier | QtCore.Qt.ShiftModifier | + QtCore.Qt.AltModifier) or self.in_stroke): + updated = False + if len(self.current_point_set) > 0: + if event.key() == QtCore.Qt.Key_Return: + self.add_set() + if self.NZ > 1: + if event.key() == QtCore.Qt.Key_Left: + self.currentZ = max(0, self.currentZ - 1) + self.scroll.setValue(self.currentZ) + updated = True + elif event.key() == QtCore.Qt.Key_Right: + self.currentZ = min(self.NZ - 1, self.currentZ + 1) + self.scroll.setValue(self.currentZ) + updated = True + else: + nviews = self.ViewDropDown.count() - 1 + nviews += int( + self.ViewDropDown.model().item(self.ViewDropDown.count() - + 1).isEnabled()) + if event.key() == QtCore.Qt.Key_X: + self.MCheckBox.toggle() + if event.key() == QtCore.Qt.Key_Z: + self.OCheckBox.toggle() + if event.key() == QtCore.Qt.Key_Left or event.key( + ) == QtCore.Qt.Key_A: + self.currentZ = max(0, self.currentZ - 1) + 
self.scroll.setValue(self.currentZ) + updated = True + elif event.key() == QtCore.Qt.Key_Right or event.key( + ) == QtCore.Qt.Key_D: + self.currentZ = min(self.NZ - 1, self.currentZ + 1) + self.scroll.setValue(self.currentZ) + updated = True + elif event.key() == QtCore.Qt.Key_PageDown: + self.view = (self.view + 1) % (nviews) + self.ViewDropDown.setCurrentIndex(self.view) + elif event.key() == QtCore.Qt.Key_PageUp: + self.view = (self.view - 1) % (nviews) + self.ViewDropDown.setCurrentIndex(self.view) + + # can change background or stroke size if cell not finished + if event.key() == QtCore.Qt.Key_Up or event.key() == QtCore.Qt.Key_W: + self.color = (self.color - 1) % (6) + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_Down or event.key( + ) == QtCore.Qt.Key_S: + self.color = (self.color + 1) % (6) + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_R: + if self.color != 1: + self.color = 1 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_G: + if self.color != 2: + self.color = 2 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif event.key() == QtCore.Qt.Key_B: + if self.color != 3: + self.color = 3 + else: + self.color = 0 + self.RGBDropDown.setCurrentIndex(self.color) + elif (event.key() == QtCore.Qt.Key_Comma or + event.key() == QtCore.Qt.Key_Period): + count = self.BrushChoose.count() + gci = self.BrushChoose.currentIndex() + if event.key() == QtCore.Qt.Key_Comma: + gci = max(0, gci - 1) + else: + gci = min(count - 1, gci + 1) + self.BrushChoose.setCurrentIndex(gci) + self.brush_choose() + if not updated: + self.update_plot() + if event.key() == QtCore.Qt.Key_Minus or event.key() == QtCore.Qt.Key_Equal: + self.p0.keyPressEvent(event) + + def update_ztext(self): + zpos = self.currentZ + try: + zpos = int(self.zpos.text()) + except: + print("ERROR: zposition is not a number") + self.currentZ = max(0, min(self.NZ 
- 1, zpos)) + self.zpos.setText(str(self.currentZ)) + self.scroll.setValue(self.currentZ) diff --git a/models/seg_post_model/cellpose/gui/guihelpwindowtext.html b/models/seg_post_model/cellpose/gui/guihelpwindowtext.html new file mode 100644 index 0000000000000000000000000000000000000000..b29aa3a26f4c28fcafdb202c0d0d9efe713cb8d8 --- /dev/null +++ b/models/seg_post_model/cellpose/gui/guihelpwindowtext.html @@ -0,0 +1,143 @@ + +

+ Main GUI mouse controls: +

+
    +
  • Pan = left-click + drag
  • +
  • Zoom = scroll wheel (or +/= and - buttons)
  • +
  • Full view = double left-click
  • +
  • Select mask = left-click on mask
  • +
  • Delete mask = Ctrl (or COMMAND on Mac) + + left-click +
  • +
  • Merge masks = Alt + left-click (will merge + last two) +
  • +
  • Start draw mask = right-click
  • +
  • End draw mask = right-click, or return to + circle at beginning +
  • +
+

Overlaps in masks are NOT allowed. If you + draw a mask on top of another mask, it is cropped so that it doesn’t overlap with the old mask. Masks in 2D + should be single strokes (single stroke is checked). If you want to draw masks in 3D (experimental), then + you can turn this option off and draw a stroke on each plane with the cell and then press ENTER. 3D + labelling will fill in planes that you have not labelled so that you do not have to as densely label. +

+

!NOTE!: The GUI automatically saves after + you draw a mask in 2D but NOT after 3D mask drawing and NOT after segmentation. Save in the file menu or + with Ctrl+S. The output file is in the same folder as the loaded image with _seg.npy appended. +

+ +

Bulk Mask Deletion + Clicking the 'delete multiple' button will allow you to select and delete multiple masks at once. + Masks can be deselected by clicking on them again. Once you have selected all the masks you want to delete, + click the 'done' button to delete them. +
+
+ Alternatively, you can create a rectangular region to delete a region of masks by clicking the + 'delete multiple' button, and then moving and/or resizing the region to select the masks you want to delete. + Once you have selected the masks you want to delete, click the 'done' button to delete them. +
+
+ At any point in the process, you can click the 'cancel' button to cancel the bulk deletion. +

+
+ +
+
+ FYI there are tooltips throughout the GUI (hover over text to see) +
+ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +
Keyboard shortcutsDescription
=/+ button // - buttonzoom in // zoom out
CTRL+Zundo previously drawn mask/stroke
CTRL+Yundo remove mask
CTRL+0clear all masks
CTRL+Lload image (can alternatively drag and drop image)
CTRL+SSAVE MASKS IN IMAGE to _seg.npy file
CTRL+Ttrain model using _seg.npy files in folder +
CTRL+Pload _seg.npy file (note: it will load automatically with image if it exists)
CTRL+Mload masks file (must be same size as image with 0 for NO mask, and 1,2,3… for masks)
CTRL+Nsave masks as PNG
CTRL+Rsave ROIs to native ImageJ ROI format
CTRL+Fsave flows to image file
A/D or LEFT/RIGHTcycle through images in current directory
W/S or UP/DOWNchange color (RGB/gray/red/green/blue)
R / G / Btoggle between RGB and Red or Green or Blue
PAGE-UP / PAGE-DOWNchange to flows and cell prob views (if segmentation computed)
Xturn masks ON or OFF
Ztoggle outlines ON or OFF
, / .increase / decrease brush size for drawing masks
+

Segmentation options + (2D only)

+

use GPU: if you have specially + installed the cuda version of torch, then you can activate this. Due to the size of the + transformer network, it will greatly speed up the processing time.

+

There are no channel options + in v4.0.1+ since all 3 channels are used for segmentation.

+
diff --git a/models/seg_post_model/cellpose/gui/guiparts.py b/models/seg_post_model/cellpose/gui/guiparts.py new file mode 100644 index 0000000000000000000000000000000000000000..b9e52721a02c9f7f40108f7a86eec4d30f914158 --- /dev/null +++ b/models/seg_post_model/cellpose/gui/guiparts.py @@ -0,0 +1,793 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu. +""" +from qtpy import QtGui, QtCore +from qtpy.QtGui import QPixmap, QDoubleValidator +from qtpy.QtWidgets import QWidget, QDialog, QGridLayout, QPushButton, QLabel, QLineEdit, QDialogButtonBox, QComboBox, QCheckBox, QVBoxLayout +import pyqtgraph as pg +import numpy as np +import pathlib, os + + +def stylesheet(): + return """ + QToolTip { + background-color: black; + color: white; + border: black solid 1px + } + QComboBox {color: white; + background-color: rgb(40,40,40);} + QComboBox::item:enabled { color: white; + background-color: rgb(40,40,40); + selection-color: white; + selection-background-color: rgb(50,100,50);} + QComboBox::item:!enabled { + background-color: rgb(40,40,40); + color: rgb(100,100,100); + } + QScrollArea > QWidget > QWidget + { + background: transparent; + border: none; + margin: 0px 0px 0px 0px; + } + + QGroupBox + { border: 1px solid white; color: rgb(255,255,255); + border-radius: 6px; + margin-top: 8px; + padding: 0px 0px;} + + QPushButton:pressed {Text-align: center; + background-color: rgb(150,50,150); + border-color: white; + color:white;} + QToolTip { + background-color: black; + color: white; + border: black solid 1px + } + QPushButton:!pressed {Text-align: center; + background-color: rgb(50,50,50); + border-color: white; + color:white;} + QToolTip { + background-color: black; + color: white; + border: black solid 1px + } + QPushButton:disabled {Text-align: center; + background-color: rgb(30,30,30); + border-color: white; + color:rgb(80,80,80);} + QToolTip { + background-color: black; + color: white; + border: 
black solid 1px + } + + """ + + +class DarkPalette(QtGui.QPalette): + """Class that inherits from pyqtgraph.QtGui.QPalette and renders dark colours for the application. + (from pykilosort/kilosort4) + """ + + def __init__(self): + QtGui.QPalette.__init__(self) + self.setup() + + def setup(self): + self.setColor(QtGui.QPalette.Window, QtGui.QColor(40, 40, 40)) + self.setColor(QtGui.QPalette.WindowText, QtGui.QColor(255, 255, 255)) + self.setColor(QtGui.QPalette.Base, QtGui.QColor(34, 27, 24)) + self.setColor(QtGui.QPalette.AlternateBase, QtGui.QColor(53, 50, 47)) + self.setColor(QtGui.QPalette.ToolTipBase, QtGui.QColor(255, 255, 255)) + self.setColor(QtGui.QPalette.ToolTipText, QtGui.QColor(255, 255, 255)) + self.setColor(QtGui.QPalette.Text, QtGui.QColor(255, 255, 255)) + self.setColor(QtGui.QPalette.Button, QtGui.QColor(53, 50, 47)) + self.setColor(QtGui.QPalette.ButtonText, QtGui.QColor(255, 255, 255)) + self.setColor(QtGui.QPalette.BrightText, QtGui.QColor(255, 0, 0)) + self.setColor(QtGui.QPalette.Link, QtGui.QColor(42, 130, 218)) + self.setColor(QtGui.QPalette.Highlight, QtGui.QColor(42, 130, 218)) + self.setColor(QtGui.QPalette.HighlightedText, QtGui.QColor(0, 0, 0)) + self.setColor(QtGui.QPalette.Disabled, QtGui.QPalette.Text, + QtGui.QColor(128, 128, 128)) + self.setColor( + QtGui.QPalette.Disabled, + QtGui.QPalette.ButtonText, + QtGui.QColor(128, 128, 128), + ) + self.setColor( + QtGui.QPalette.Disabled, + QtGui.QPalette.WindowText, + QtGui.QColor(128, 128, 128), + ) + + +# def create_channel_choose(): +# # choose channel +# ChannelChoose = [QComboBox(), QComboBox()] +# ChannelLabels = [] +# ChannelChoose[0].addItems(["gray", "red", "green", "blue"]) +# ChannelChoose[1].addItems(["none", "red", "green", "blue"]) +# cstr = ["chan to segment:", "chan2 (optional): "] +# for i in range(2): +# ChannelLabels.append(QLabel(cstr[i])) +# if i == 0: +# ChannelLabels[i].setToolTip( +# "this is the channel in which the cytoplasm or nuclei exist \ +# that you want to 
segment") +# ChannelChoose[i].setToolTip( +# "this is the channel in which the cytoplasm or nuclei exist \ +# that you want to segment") +# else: +# ChannelLabels[i].setToolTip( +# "if cytoplasm model is chosen, and you also have a \ +# nuclear channel, then choose the nuclear channel for this option") +# ChannelChoose[i].setToolTip( +# "if cytoplasm model is chosen, and you also have a \ +# nuclear channel, then choose the nuclear channel for this option") + +# return ChannelChoose, ChannelLabels + + +class ModelButton(QPushButton): + + def __init__(self, parent, model_name, text): + super().__init__() + self.setEnabled(False) + self.setText(text) + self.setFont(parent.boldfont) + self.clicked.connect(lambda: self.press(parent)) + self.model_name = "cpsam" + + def press(self, parent): + parent.compute_segmentation(model_name="cpsam") + + +class FilterButton(QPushButton): + + def __init__(self, parent, text): + super().__init__() + self.setEnabled(False) + self.model_type = text + self.setText(text) + self.setFont(parent.medfont) + self.clicked.connect(lambda: self.press(parent)) + + def press(self, parent): + if self.model_type == "filter": + parent.restore = "filter" + normalize_params = parent.get_normalize_params() + if (normalize_params["sharpen_radius"] == 0 and + normalize_params["smooth_radius"] == 0 and + normalize_params["tile_norm_blocksize"] == 0): + print( + "GUI_ERROR: no filtering settings on (use custom filter settings)") + parent.restore = None + return + parent.restore = self.model_type + parent.compute_saturation() + # elif self.model_type != "none": + # parent.compute_denoise_model(model_type=self.model_type) + else: + parent.clear_restore() + # parent.set_restore_button() + + +class ObservableVariable(QtCore.QObject): + valueChanged = QtCore.Signal(object) + + def __init__(self, initial=None): + super().__init__() + self._value = initial + + def set(self, new_value): + """ Use this method to get emit the value changing and update the ROI 
count""" + if new_value != self._value: + self._value = new_value + self.valueChanged.emit(new_value) + + def get(self): + return self._value + + def __call__(self): + return self._value + + def reset(self): + self.set(0) + + def __iadd__(self, amount): + if not isinstance(amount, (int, float)): + raise TypeError("Value must be numeric.") + self.set(self._value + amount) + return self + + def __radd__(self, other): + return other + self._value + + def __add__(self, other): + return other + self._value + + def __isub__(self, amount): + if not isinstance(amount, (int, float)): + raise TypeError("Value must be numeric.") + self.set(self._value - amount) + return self + + def __str__(self): + return str(self._value) + + def __lt__(self, x): + return self._value < x + + def __gt__(self, x): + return self._value > x + + def __eq__(self, x): + return self._value == x + + +class NormalizationSettings(QWidget): + # TODO + pass + + +class SegmentationSettings(QWidget): + """ Container for gui settings. Validation is done automatically so any attributes can + be acessed without concern. + """ + def __init__(self, font): + super().__init__() + + # Put everything in a grid layout: + grid_layout = QGridLayout() + widget_container = QWidget() + widget_container.setLayout(grid_layout) + row = 0 + + ########################### Diameter ########################### + # TODO: Validate inputs + diam_qlabel = QLabel("diameter:") + diam_qlabel.setToolTip("diameter of cells in pixels. If not 30, image will be resized to this") + diam_qlabel.setFont(font) + grid_layout.addWidget(diam_qlabel, row, 0, 1, 2) + self.diameter_box = QLineEdit() + self.diameter_box.setToolTip("diameter of cells in pixels. 
If not blank, image will be resized relative to 30 pixel cell diameters") + self.diameter_box.setFont(font) + self.diameter_box.setFixedWidth(40) + self.diameter_box.setText(' ') + grid_layout.addWidget(self.diameter_box, row, 2, 1, 2) + + row += 1 + + ########################### Flow threshold ########################### + # TODO: Validate inputs + flow_threshold_qlabel = QLabel("flow\nthreshold:") + flow_threshold_qlabel.setToolTip("threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run") + flow_threshold_qlabel.setFont(font) + grid_layout.addWidget(flow_threshold_qlabel, row, 0, 1, 2) + self.flow_threshold_box = QLineEdit() + self.flow_threshold_box.setText("0.4") + self.flow_threshold_box.setFixedWidth(40) + self.flow_threshold_box.setFont(font) + grid_layout.addWidget(self.flow_threshold_box, row, 2, 1, 2) + self.flow_threshold_box.setToolTip("threshold on flow error to accept a mask (set higher to get more cells, e.g. in range from (0.1, 3.0), OR set to 0.0 to turn off so no cells discarded);\n press enter to recompute if model already run") + + ########################### Cellprob threshold ########################### + # TODO: Validate inputs + cellprob_qlabel = QLabel("cellprob\nthreshold:") + cellprob_qlabel.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. 
in range from (-6, 6)); \n press enter to recompute if model already run") + cellprob_qlabel.setFont(font) + grid_layout.addWidget(cellprob_qlabel, row, 4, 1, 2) + self.cellprob_threshold_box = QLineEdit() + self.cellprob_threshold_box.setText("0.0") + self.cellprob_threshold_box.setFixedWidth(40) + self.cellprob_threshold_box.setFont(font) + self.cellprob_threshold_box.setToolTip("threshold on cellprob output to seed cell masks (set lower to include more pixels or higher to include fewer, e.g. in range from (-6, 6)); \n press enter to recompute if model already run") + grid_layout.addWidget(self.cellprob_threshold_box, row, 6, 1, 2) + + row += 1 + + ########################### Norm percentiles ########################### + norm_percentiles_qlabel = QLabel("norm percentiles:") + norm_percentiles_qlabel.setToolTip("sets normalization percentiles for segmentation and denoising\n(pixels at lower percentile set to 0.0 and at upper set to 1.0 for network)") + norm_percentiles_qlabel.setFont(font) + grid_layout.addWidget(norm_percentiles_qlabel, row, 0, 1, 8) + + row += 1 + validator = QDoubleValidator(0.0, 100.0, 2) + validator.setNotation(QDoubleValidator.StandardNotation) + + low_norm_qlabel = QLabel('lower:') + low_norm_qlabel.setToolTip("pixels at this percentile set to 0 (default 1.0)") + low_norm_qlabel.setFont(font) + grid_layout.addWidget(low_norm_qlabel, row, 0, 1, 2) + self.norm_percentile_low_box = QLineEdit() + self.norm_percentile_low_box.setText("1.0") + self.norm_percentile_low_box.setFont(font) + self.norm_percentile_low_box.setFixedWidth(40) + self.norm_percentile_low_box.setToolTip("pixels at this percentile set to 0 (default 1.0)") + self.norm_percentile_low_box.setValidator(validator) + self.norm_percentile_low_box.editingFinished.connect(self.validate_normalization_range) + grid_layout.addWidget(self.norm_percentile_low_box, row, 2, 1, 1) + + high_norm_qlabel = QLabel('upper:') + high_norm_qlabel.setToolTip("pixels at this percentile set to 1 
(default 99.0)") + high_norm_qlabel.setFont(font) + grid_layout.addWidget(high_norm_qlabel, row, 4, 1, 2) + self.norm_percentile_high_box = QLineEdit() + self.norm_percentile_high_box.setText("99.0") + self.norm_percentile_high_box.setFont(font) + self.norm_percentile_high_box.setFixedWidth(40) + self.norm_percentile_high_box.setToolTip("pixels at this percentile set to 1 (default 99.0)") + self.norm_percentile_high_box.setValidator(validator) + self.norm_percentile_high_box.editingFinished.connect(self.validate_normalization_range) + grid_layout.addWidget(self.norm_percentile_high_box, row, 6, 1, 2) + + row += 1 + + ########################### niter ########################### + # TODO: change this to follow the same default logic as 'diameter' above + # TODO: input validation + niter_qlabel = QLabel("niter dynamics:") + niter_qlabel.setFont(font) + niter_qlabel.setToolTip("number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria") + grid_layout.addWidget(niter_qlabel, row, 0, 1, 4) + self.niter_box = QLineEdit() + self.niter_box.setText("0") + self.niter_box.setFixedWidth(40) + self.niter_box.setFont(font) + self.niter_box.setToolTip("number of iterations for dynamics (0 uses default based on diameter); use 2000 for bacteria") + grid_layout.addWidget(self.niter_box, row, 4, 1, 2) + + self.setLayout(grid_layout) + + def validate_normalization_range(self): + low_text = self.norm_percentile_low_box.text() + high_text = self.norm_percentile_high_box.text() + + if not low_text or low_text.isspace(): + self.norm_percentile_low_box.setText('1.0') + low_text = '1.0' + elif not high_text or high_text.isspace(): + self.norm_percentile_high_box.setText('1.0') + high_text = '99.0' + + low = float(low_text) + high = float(high_text) + + if low >= high: + # Invalid: show error and mark fields + self.norm_percentile_low_box.setStyleSheet("border: 1px solid red;") + self.norm_percentile_high_box.setStyleSheet("border: 1px solid red;") + else: + 
        # Valid input: clear any error styling previously set on the boxes.
        self.norm_percentile_low_box.setStyleSheet("")
        self.norm_percentile_high_box.setStyleSheet("")

    @property
    def low_percentile(self):
        """ Also validate the low input by returning 1.0 if text doesn't work """
        # Empty/whitespace text is rewritten to the default '1.0' in the widget
        # first, so the box and the returned value stay in sync.
        low_text = self.norm_percentile_low_box.text()
        if not low_text or low_text.isspace():
            self.norm_percentile_low_box.setText('1.0')
            low_text = '1.0'
        return float(self.norm_percentile_low_box.text())

    @property
    def high_percentile(self):
        """ Also validate the high input by returning 99.0 if text doesn't work """
        high_text = self.norm_percentile_high_box.text()
        if not high_text or high_text.isspace():
            self.norm_percentile_high_box.setText('99.0')
            high_text = '99.0'
        return float(self.norm_percentile_high_box.text())

    @property
    def diameter(self):
        """ Get the diameter from the diameter box, if box isn't a number return None"""
        try:
            d = float(self.diameter_box.text())
        except ValueError:
            d = None
        return d

    @property
    def flow_threshold(self):
        # No validation here: raises ValueError if the box is not a number.
        return float(self.flow_threshold_box.text())

    @property
    def cellprob_threshold(self):
        # No validation here: raises ValueError if the box is not a number.
        return float(self.cellprob_threshold_box.text())

    @property
    def niter(self):
        # Values < 1 are replaced by the default 200 (and written back to the box).
        num = int(self.niter_box.text())
        if num < 1:
            self.niter_box.setText('200')
            return 200
        else:
            return num



class TrainWindow(QDialog):
    """Dialog for configuring training on the images + *_seg.npy files in the
    current folder.  On OK, writes the chosen values into parent.training_params."""

    def __init__(self, parent, model_strings):
        super().__init__(parent)
        self.setGeometry(100, 100, 900, 550)
        self.setWindowTitle("train settings")
        self.win = QWidget(self)
        self.l0 = QGridLayout()
        self.win.setLayout(self.l0)

        yoff = 0
        qlabel = QLabel("train model w/ images + _seg.npy in current folder >>")
        qlabel.setFont(QtGui.QFont("Arial", 10, QtGui.QFont.Bold))

        qlabel.setAlignment(QtCore.Qt.AlignVCenter)
        self.l0.addWidget(qlabel, yoff, 0, 1, 2)

        # choose initial model
        yoff += 1
        self.ModelChoose = QComboBox()
        self.ModelChoose.addItems(model_strings)
        self.ModelChoose.setFixedWidth(150)
        self.ModelChoose.setCurrentIndex(parent.training_params["model_index"])
        self.l0.addWidget(self.ModelChoose, yoff, 1, 1, 1)
        qlabel = QLabel("initial model: ")
        qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
        self.l0.addWidget(qlabel, yoff, 0, 1, 1)

        # choose parameters: one QLineEdit per label, pre-filled from
        # parent.training_params; read back positionally in accept().
        labels = ["learning_rate", "weight_decay", "n_epochs", "model_name"]
        self.edits = []
        yoff += 1
        for i, label in enumerate(labels):
            qlabel = QLabel(label)
            qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
            self.l0.addWidget(qlabel, i + yoff, 0, 1, 1)
            self.edits.append(QLineEdit())
            self.edits[-1].setText(str(parent.training_params[label]))
            self.edits[-1].setFixedWidth(200)
            self.l0.addWidget(self.edits[-1], i + yoff, 1, 1, 1)

        yoff += len(labels)

        yoff += 1
        # NOTE(review): this checkbox is never added to the layout and its value
        # is not read in accept() (the "use_norm" entry there is commented out) —
        # confirm whether it should be wired up or removed.
        self.use_norm = QCheckBox(f"use restored/filtered image")
        self.use_norm.setChecked(True)

        yoff += 2
        qlabel = QLabel(
            "(to remove files, click cancel then remove \nfrom folder and reopen train window)"
        )
        self.l0.addWidget(qlabel, yoff, 0, 2, 4)

        # click button
        yoff += 3
        QBtn = QDialogButtonBox.Ok | QDialogButtonBox.Cancel
        self.buttonBox = QDialogButtonBox(QBtn)
        self.buttonBox.accepted.connect(lambda: self.accept(parent))
        self.buttonBox.rejected.connect(self.reject)
        self.l0.addWidget(self.buttonBox, yoff, 0, 1, 4)

        # list files in folder (at most 10 rows; row 10 becomes "..." if more)
        qlabel = QLabel("filenames")
        qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
        self.l0.addWidget(qlabel, 0, 4, 1, 1)
        qlabel = QLabel("# of masks")
        qlabel.setFont(QtGui.QFont("Arial", 8, QtGui.QFont.Bold))
        self.l0.addWidget(qlabel, 0, 5, 1, 1)

        for i in range(10):
            if i > len(parent.train_files) - 1:
                break
            elif i == 9 and len(parent.train_files) > 10:
                label = "..."
                nmasks = "..."
            else:
                label = os.path.split(parent.train_files[i])[-1]
                # masks are consecutively numbered, so max() == count of masks
                nmasks = str(parent.train_labels[i].max())
            qlabel = QLabel(label)
            self.l0.addWidget(qlabel, i + 1, 4, 1, 1)
            qlabel = QLabel(nmasks)
            qlabel.setAlignment(QtCore.Qt.AlignRight | QtCore.Qt.AlignVCenter)
            self.l0.addWidget(qlabel, i + 1, 5, 1, 1)

    def accept(self, parent):
        """Copy the edited values back into parent.training_params and close
        with result code 1.  Edits are read positionally, matching the `labels`
        order used in __init__."""
        # set training params
        parent.training_params = {
            "model_index": self.ModelChoose.currentIndex(),
            "learning_rate": float(self.edits[0].text()),
            "weight_decay": float(self.edits[1].text()),
            "n_epochs": int(self.edits[2].text()),
            "model_name": self.edits[3].text(),
            #"use_norm": True if self.use_norm.isChecked() else False,
        }
        self.done(1)


class ExampleGUI(QDialog):
    """Window displaying the GUI-layout screenshot from ~/.cellpose/cellposeSAM_gui.png."""

    def __init__(self, parent=None):
        super(ExampleGUI, self).__init__(parent)
        self.setGeometry(100, 100, 1300, 900)
        self.setWindowTitle("GUI layout")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)
        guip_path = pathlib.Path.home().joinpath(".cellpose", "cellposeSAM_gui.png")
        guip_path = str(guip_path.resolve())
        pixmap = QPixmap(guip_path)
        label = QLabel(self)
        label.setPixmap(pixmap)
        # NOTE(review): no-op — `scaled` is accessed but never called; the pixmap
        # is shown at its native size.  Probably meant pixmap = pixmap.scaled(...).
        pixmap.scaled
        layout.addWidget(label, 0, 0, 1, 1)


class HelpWindow(QDialog):
    """Window rendering guihelpwindowtext.html (shipped next to this module)."""

    def __init__(self, parent=None):
        super(HelpWindow, self).__init__(parent)
        self.setGeometry(100, 50, 700, 1000)
        self.setWindowTitle("cellpose help")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)

        text_file = pathlib.Path(__file__).parent.joinpath("guihelpwindowtext.html")
        with open(str(text_file.resolve()), "r") as f:
            text = f.read()

        label = QLabel(text)
        label.setFont(QtGui.QFont("Arial", 8))
        label.setWordWrap(True)
        layout.addWidget(label, 0, 0, 1, 1)
        self.show()


class TrainHelpWindow(QDialog):
    """Window rendering guitrainhelpwindowtext.html (training instructions)."""

    def __init__(self, parent=None):
        super(TrainHelpWindow, self).__init__(parent)
        self.setGeometry(100, 50, 700, 300)
        self.setWindowTitle("training instructions")
        self.win = QWidget(self)
        layout = QGridLayout()
        self.win.setLayout(layout)

        text_file = pathlib.Path(__file__).parent.joinpath(
            "guitrainhelpwindowtext.html")
        with open(str(text_file.resolve()), "r") as f:
            text = f.read()

        label = QLabel(text)
        label.setFont(QtGui.QFont("Arial", 8))
        label.setWordWrap(True)
        layout.addWidget(label, 0, 0, 1, 1)
        self.show()


class ViewBoxNoRightDrag(pg.ViewBox):
    """pyqtgraph ViewBox that additionally handles +/- key zooming."""

    def __init__(self, parent=None, border=None, lockAspect=False, enableMouse=True,
                 invertY=False, enableMenu=True, name=None, invertX=False):
        # parent is stored on self but deliberately NOT passed to pg.ViewBox
        # (first positional arg is None).
        pg.ViewBox.__init__(self, None, border, lockAspect, enableMouse, invertY,
                            enableMenu, name, invertX)
        self.parent = parent
        self.axHistoryPointer = -1

    def keyPressEvent(self, ev):
        """
        This routine should capture key presses in the current view box.
        The following events are implemented:
        +/= : moves forward in the zooming stack (if it exists)
        - : moves backward in the zooming stack (if it exists)

        """
        # accept first; unhandled keys are re-ignored below so they propagate
        ev.accept()
        if ev.text() == "-":
            self.scaleBy([1.1, 1.1])
        elif ev.text() in ["+", "="]:
            self.scaleBy([0.9, 0.9])
        else:
            ev.ignore()


class ImageDraw(pg.ImageItem):
    """
    **Bases:** :class:`GraphicsObject `
    GraphicsObject displaying an image. Optimized for rapid update (ie video display).
    This item displays either a 2D numpy array (height, width) or
    a 3D array (height, width, RGBa). This array is optionally scaled (see
    :func:`setLevels `) and/or colored
    with a lookup table (see :func:`setLookupTable `)
    before being displayed.
    ImageItem is frequently used in conjunction with
    :class:`HistogramLUTItem ` or
    :class:`HistogramLUTWidget ` to provide a GUI
    for controlling the levels and lookup table used to display the image.
    """

    sigImageChanged = QtCore.Signal()

    def __init__(self, image=None, viewbox=None, parent=None, **kargs):
        # image/viewbox/kargs are accepted for ImageItem-compatibility but not
        # forwarded; all state lives on `parent` (the main GUI window).
        super(ImageDraw, self).__init__()
        self.levels = np.array([0, 255])
        self.lut = None
        self.autoDownsample = False
        self.axisOrder = "row-major"
        self.removable = False

        self.parent = parent
        self.setDrawKernel(kernel_size=self.parent.brush_size)
        self.parent.current_stroke = []
        self.parent.in_stroke = False

    def mouseClickEvent(self, ev):
        """Right/Shift-click starts or ends a drawing stroke; left-click selects,
        Ctrl+click deletes, Alt+click merges the mask under the cursor."""
        if (self.parent.masksOn or
                self.parent.outlinesOn) and not self.parent.removing_region:
            is_right_click = ev.button() == QtCore.Qt.RightButton
            if self.parent.loaded \
                    and (is_right_click or ev.modifiers() & QtCore.Qt.ShiftModifier and not ev.double())\
                    and not self.parent.deleting_multiple:
                if not self.parent.in_stroke:
                    # begin a new stroke at the click position
                    ev.accept()
                    self.create_start(ev.pos())
                    self.parent.stroke_appended = False
                    self.parent.in_stroke = True
                    self.drawAt(ev.pos(), ev)
                else:
                    # second click while drawing closes the stroke
                    ev.accept()
                    self.end_stroke()
                    self.parent.in_stroke = False
            elif not self.parent.in_stroke:
                y, x = int(ev.pos().y()), int(ev.pos().x())
                if y >= 0 and y < self.parent.Ly and x >= 0 and x < self.parent.Lx:
                    if ev.button() == QtCore.Qt.LeftButton and not ev.double():
                        idx = self.parent.cellpix[self.parent.currentZ][y, x]
                        if idx > 0:
                            if ev.modifiers() & QtCore.Qt.ControlModifier:
                                # delete mask selected
                                self.parent.remove_cell(idx)
                            elif ev.modifiers() & QtCore.Qt.AltModifier:
                                self.parent.merge_cells(idx)
                            elif self.parent.masksOn and not self.parent.deleting_multiple:
                                self.parent.unselect_cell()
                                self.parent.select_cell(idx)
                            elif self.parent.deleting_multiple:
                                # toggle membership in the multi-delete list
                                if idx in self.parent.removing_cells_list:
                                    self.parent.unselect_cell_multi(idx)
                                    self.parent.removing_cells_list.remove(idx)
                                else:
                                    self.parent.select_cell_multi(idx)
                                    self.parent.removing_cells_list.append(idx)

                        elif self.parent.masksOn and not self.parent.deleting_multiple:
                            # clicked background: clear selection
                            self.parent.unselect_cell()

    def mouseDragEvent(self, ev):
        # dragging is intentionally disabled; strokes are drawn via hoverEvent
        ev.ignore()
        return

    def hoverEvent(self, ev):
        # NOTE(review): the nested `if self.parent.in_stroke` duplicates the
        # outer condition and is always true here.
        if self.parent.in_stroke:
            if self.parent.in_stroke:
                # continue stroke if not at start
                self.drawAt(ev.pos())
                if self.is_at_start(ev.pos()):
                    self.end_stroke()
        else:
            ev.acceptClicks(QtCore.Qt.RightButton)

    def create_start(self, pos):
        """Show a red circle marking the stroke's start point."""
        self.scatter = pg.ScatterPlotItem([pos.x()], [pos.y()], pxMode=False,
                                          pen=pg.mkPen(color=(255, 0, 0),
                                                       width=self.parent.brush_size),
                                          size=max(3 * 2,
                                                   self.parent.brush_size * 1.8 * 2),
                                          brush=None)
        self.parent.p0.addItem(self.scatter)

    def is_at_start(self, pos):
        """Return True once the cursor has left the start circle (dist > thresh_out)
        and come back within thresh_in of it — i.e. the stroke is closed.
        NOTE(review): implicitly returns None (falsy) while the stroke has <= 3
        points; `pos` itself is unused, only the recorded stroke points are tested."""
        thresh_out = max(6, self.parent.brush_size * 3)
        thresh_in = max(3, self.parent.brush_size * 1.8)
        # first check if you ever left the start
        if len(self.parent.current_stroke) > 3:
            stroke = np.array(self.parent.current_stroke)
            # distances of every later point from the first recorded point
            dist = (((stroke[1:, 1:] -
                      stroke[:1, 1:][np.newaxis, :, :])**2).sum(axis=-1))**0.5
            dist = dist.flatten()
            has_left = (dist > thresh_out).nonzero()[0]
            if len(has_left) > 0:
                first_left = np.sort(has_left)[0]
                has_returned = (dist[max(4, first_left + 1):] < thresh_in).sum()
                if has_returned > 0:
                    return True
                else:
                    return False
            else:
                return False

    def end_stroke(self):
        """Finalize the current stroke: record it, extract its outline points,
        and optionally autosave the resulting mask."""
        self.parent.p0.removeItem(self.scatter)
        if not self.parent.stroke_appended:
            self.parent.strokes.append(self.parent.current_stroke)
            self.parent.stroke_appended = True
            self.parent.current_stroke = np.array(self.parent.current_stroke)
            # column 3 flags kernel-center points, which form the outline
            ioutline = self.parent.current_stroke[:, 3] == 1
            self.parent.current_point_set.append(
                list(self.parent.current_stroke[ioutline]))
            self.parent.current_stroke = []
            if self.parent.autosave:
                self.parent.add_set()
        # NOTE(review): when autosave is on, add_set() can be reached twice in
        # one call (above and here) — confirm add_set() is idempotent / clears
        # current_point_set.
        if len(self.parent.current_point_set) and len(
                self.parent.current_point_set[0]) > 0 and self.parent.autosave:
            self.parent.add_set()
        self.parent.in_stroke = False

    def tabletEvent(self, ev):
        # tablet/stylus input deliberately unhandled
        pass

    def drawAt(self, pos, ev=None):
        """Stamp the brush kernel onto the overlay at `pos` (clipped to the image
        bounds) and append each painted pixel to parent.current_stroke as
        [currentZ, x, y, is_center]."""
        mask = self.strokemask
        stroke = self.parent.current_stroke
        pos = [int(pos.y()), int(pos.x())]
        dk = self.drawKernel
        kc = self.drawKernelCenter
        # s* = source (kernel) slice bounds, t* = target (image) slice bounds
        sx = [0, dk.shape[0]]
        sy = [0, dk.shape[1]]
        tx = [pos[0] - kc[0], pos[0] - kc[0] + dk.shape[0]]
        ty = [pos[1] - kc[1], pos[1] - kc[1] + dk.shape[1]]
        kcent = kc.copy()
        # clip against top/left edges
        if tx[0] <= 0:
            sx[0] = 0
            sx[1] = kc[0] + 1
            tx = sx
            kcent[0] = 0
        if ty[0] <= 0:
            sy[0] = 0
            sy[1] = kc[1] + 1
            ty = sy
            kcent[1] = 0
        # clip against bottom/right edges (image is parent.Ly x parent.Lx)
        if tx[1] >= self.parent.Ly - 1:
            sx[0] = dk.shape[0] - kc[0] - 1
            sx[1] = dk.shape[0]
            tx[0] = self.parent.Ly - kc[0] - 1
            tx[1] = self.parent.Ly
            kcent[0] = tx[1] - tx[0] - 1
        if ty[1] >= self.parent.Lx - 1:
            sy[0] = dk.shape[1] - kc[1] - 1
            sy[1] = dk.shape[1]
            ty[0] = self.parent.Lx - kc[1] - 1
            ty[1] = self.parent.Lx
            kcent[1] = ty[1] - ty[0] - 1

        ts = (slice(tx[0], tx[1]), slice(ty[0], ty[1]))
        ss = (slice(sx[0], sx[1]), slice(sy[0], sy[1]))
        self.image[ts] = mask[ss]

        for ky, y in enumerate(np.arange(ty[0], ty[1], 1, int)):
            for kx, x in enumerate(np.arange(tx[0], tx[1], 1, int)):
                iscent = np.logical_and(kx == kcent[0], ky == kcent[1])
                stroke.append([self.parent.currentZ, x, y, iscent])
        self.updateImage()

    def setDrawKernel(self, kernel_size=3):
        """Build a square brush kernel of side kernel_size plus the RGBA stamp
        masks used when painting (redmask unused here but kept for callers)."""
        bs = kernel_size
        kernel = np.ones((bs, bs), np.uint8)
        self.drawKernel = kernel
        self.drawKernelCenter = [
            int(np.floor(kernel.shape[0] / 2)),
            int(np.floor(kernel.shape[1] / 2))
        ]
        onmask = 255 * kernel[:, :, np.newaxis]
        offmask = np.zeros((bs, bs, 1))
        opamask = 100 * kernel[:, :, np.newaxis]
        self.redmask = np.concatenate((onmask, offmask, offmask, onmask), axis=-1)
        self.strokemask = np.concatenate((onmask, offmask, onmask, opamask), axis=-1)
b/models/seg_post_model/cellpose/gui/guitrainhelpwindowtext.html @@ -0,0 +1,25 @@ + + Check out this video to learn the process. +
    +
  1. Drag and drop an image from a folder of images with a similar style (like similar cell types).
  2. Run the built-in models on one of the images using the "model zoo" and find the one that works best for your data. Make sure that if you have a nuclear channel you have selected it for CHAN2.
  3. Fix the labelling by drawing new ROIs (right-click) and deleting incorrect ones (CTRL+click). The GUI autosaves any manual changes (but does not autosave after running the model, for that click CTRL+S). The segmentation is saved in a "_seg.npy" file.
  4. Go to the "Models" menu in the File bar at the top and click "Train new model..." or use shortcut CTRL+T.
  5. Choose the pretrained model to start the training from (the model you used in #2), and type in the model name that you want to use. The other parameters should work well in general for most data types. Then click OK.
  6. The model will train (much faster if you have a GPU) and then auto-run on the next image in the folder. Next you can repeat #3-#5 as many times as is necessary.
  7. The trained model is available to use in the future in the GUI in the "custom model" section and is saved in your image folder.
+
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu.

GUI I/O helpers: loading images/_seg.npy/masks into the main window (`parent`)
and saving masks, flows, outlines, ROIs and _seg.npy files back out.
"""
import os, gc
import numpy as np
import cv2
import fastremap

from ..io import imread, imread_2D, imread_3D, imsave, outlines_to_text, add_model, remove_model, save_rois
from ..models import normalize_default, MODEL_DIR, MODEL_LIST_PATH, get_user_models
from ..utils import masks_to_outlines, outlines_list

# Qt is optional at import time so the module can be used headless.
try:
    import qtpy
    from qtpy.QtWidgets import QFileDialog
    GUI = True
except:
    GUI = False

try:
    import matplotlib.pyplot as plt
    MATPLOTLIB = True
except:
    MATPLOTLIB = False


def _init_model_list(parent):
    """Ensure the model directory exists and load the user model list onto parent."""
    MODEL_DIR.mkdir(parents=True, exist_ok=True)
    parent.model_list_path = MODEL_LIST_PATH
    parent.model_strings = get_user_models()


def _add_model(parent, filename=None, load_model=True):
    """Register a custom model file (dialog if filename is None), add it to the
    combo box, drop any older duplicate entry, and optionally load it."""
    if filename is None:
        name = QFileDialog.getOpenFileName(parent, "Add model to GUI")
        filename = name[0]
    add_model(filename)
    fname = os.path.split(filename)[-1]
    parent.ModelChooseC.addItems([fname])
    parent.model_strings.append(fname)

    # if the same name was already registered, remove the older entry
    for ind, model_string in enumerate(parent.model_strings[:-1]):
        if model_string == fname:
            _remove_model(parent, ind=ind + 1, verbose=False)

    parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
    if load_model:
        parent.model_choose(custom=True)


def _remove_model(parent, ind=None, verbose=True):
    """Remove the model at combo-box index `ind` (current selection if None).
    Index 0 is the placeholder entry and cannot be removed.
    NOTE(review): `verbose` is currently unused; also currentText() is read
    AFTER removeItem, so the name passed to remove_model may not be the removed
    entry — confirm intended."""
    if ind is None:
        ind = parent.ModelChooseC.currentIndex()
    if ind > 0:
        ind -= 1
        parent.ModelChooseC.removeItem(ind + 1)
        del parent.model_strings[ind]
        # remove model from txt path
        modelstr = parent.ModelChooseC.currentText()
        remove_model(modelstr)
        if len(parent.model_strings) > 0:
            parent.ModelChooseC.setCurrentIndex(len(parent.model_strings))
        else:
            parent.ModelChooseC.setCurrentIndex(0)
    else:
        print("ERROR: no model selected to delete")


def _get_train_set(image_names):
    """ get training data and labels for images in current folder image_names"""
    train_data, train_labels, train_files = [], [], []
    restore = None
    normalize_params = normalize_default
    for image_name_full in image_names:
        image_name = os.path.splitext(image_name_full)[0]
        label_name = None
        if os.path.exists(image_name + "_seg.npy"):
            dat = np.load(image_name + "_seg.npy", allow_pickle=True).item()
            masks = dat["masks"].squeeze()
            if masks.ndim == 2:
                fastremap.renumber(masks, in_place=True)
                label_name = image_name + "_seg.npy"
            else:
                print(f"GUI_INFO: _seg.npy found for {image_name} but masks.ndim!=2")
            # prefer the restored/filtered image saved in the seg file, if any
            if "img_restore" in dat:
                data = dat["img_restore"].squeeze()
                restore = dat["restore"]
            else:
                data = imread(image_name_full)
            normalize_params = dat[
                "normalize_params"] if "normalize_params" in dat else normalize_default
        if label_name is not None:
            train_files.append(image_name_full)
            train_data.append(data)
            train_labels.append(masks)
    if restore:
        print(f"GUI_INFO: using {restore} images (dat['img_restore'])")
    return train_data, train_labels, train_files, restore, normalize_params


def _load_image(parent, filename=None, load_seg=True, load_3D=False):
    """ load image with filename; if None, open QFileDialog
    if image is grey change view to default to grey scale
    """

    if parent.load_3D:
        load_3D = True

    if filename is None:
        name = QFileDialog.getOpenFileName(parent, "Load image")
        filename = name[0]
        if filename == "":
            return
    manual_file = os.path.splitext(filename)[0] + "_seg.npy"
    load_mask = False
    if load_seg:
        # existing _seg.npy takes precedence unless autoload-masks is checked
        if os.path.isfile(manual_file) and not parent.autoloadMasks.isChecked():
            if filename is not None:
                image = (imread_2D(filename) if not load_3D else
                         imread_3D(filename))
            else:
                image = None
            _load_seg(parent, manual_file, image=image, image_file=filename,
                      load_3D=load_3D)
            return
        elif parent.autoloadMasks.isChecked():
            # look for <name>_masks.<ext>, falling back to <name>_masks.tif
            mask_file = os.path.splitext(filename)[0] + "_masks" + os.path.splitext(
                filename)[-1]
            mask_file = os.path.splitext(filename)[
                0] + "_masks.tif" if not os.path.isfile(mask_file) else mask_file
            load_mask = True if os.path.isfile(mask_file) else False
    try:
        # NOTE(review): f-string has no placeholder — literally prints
        # "(unknown)"; presumably meant {filename}.
        print(f"GUI_INFO: loading image: (unknown)")
        if not load_3D:
            image = imread_2D(filename)
        else:
            image = imread_3D(filename)
        parent.loaded = True
    except Exception as e:
        # NOTE(review): on failure `image` may be unbound and parent.loaded keeps
        # its previous value, so the code below can still run — confirm intended.
        print("ERROR: images not compatible")
        print(f"ERROR: {e}")

    if parent.loaded:
        parent.reset()
        parent.filename = filename
        filename = os.path.split(parent.filename)[-1]
        _initialize_images(parent, image, load_3D=load_3D)
        parent.loaded = True
        parent.enable_buttons()
        if load_mask:
            _load_masks(parent, filename=mask_file)

    # check if gray and adjust viewer:
    # all non-first channels constant => single-channel image, switch view to gray
    if len(np.unique(image[..., 1:])) == 1:
        parent.color = 4
        parent.RGBDropDown.setCurrentIndex(4)  # gray
        parent.update_plot()


def _initialize_images(parent, image, load_3D=False):
    """ format image for GUI

    assumes image is Z x W x H x C

    """
    load_3D = parent.load_3D if load_3D is False else load_3D

    parent.stack = image
    print(f"GUI_INFO: image shape: {image.shape}")
    if load_3D:
        parent.NZ = len(parent.stack)
        parent.scroll.setMaximum(parent.NZ - 1)
    else:
        parent.NZ = 1
        parent.stack = parent.stack[np.newaxis, ...]

    # global min/max normalization to 0..255 float32 (guard avoids divide-by-zero
    # on constant images)
    img_min = image.min()
    img_max = image.max()
    parent.stack = parent.stack.astype(np.float32)
    parent.stack -= img_min
    if img_max > img_min + 1e-3:
        parent.stack /= (img_max - img_min)
    parent.stack *= 255

    if load_3D:
        print("GUI_INFO: converted to float and normalized values to 0.0->255.0")

    del image
    gc.collect()

    parent.imask = 0
    parent.Ly, parent.Lx = parent.stack.shape[-3:-1]
    parent.Ly0, parent.Lx0 = parent.stack.shape[-3:-1]
    parent.layerz = 255 * np.ones((parent.Ly, parent.Lx, 4), "uint8")
    # restored/upsampled stacks can have a different resolution (Lyr, Lxr)
    if hasattr(parent, "stack_filtered"):
        parent.Lyr, parent.Lxr = parent.stack_filtered.shape[-3:-1]
    elif parent.restore and "upsample" in parent.restore:
        parent.Lyr, parent.Lxr = int(parent.Ly * parent.ratio), int(parent.Lx *
                                                                    parent.ratio)
    else:
        parent.Lyr, parent.Lxr = parent.Ly, parent.Lx
    parent.clear_all()

    if not hasattr(parent, "stack_filtered") and parent.restore:
        print("GUI_INFO: no 'img_restore' found, applying current settings")
        parent.compute_restore()

    if parent.autobtn.isChecked():
        if parent.restore is None or parent.restore != "filter":
            print(
                "GUI_INFO: normalization checked: computing saturation levels (and optionally filtered image)"
            )
            parent.compute_saturation()
        # elif len(parent.saturation) != parent.NZ:
        #     parent.saturation = []
        #     for r in range(3):
        #         parent.saturation.append([])
        #         for n in range(parent.NZ):
        #             parent.saturation[-1].append([0, 255])
        #             parent.sliders[r].setValue([0, 255])
    parent.compute_scale()
    parent.track_changes = []

    if load_3D:
        # start in the middle plane of the volume
        parent.currentZ = int(np.floor(parent.NZ / 2))
        parent.scroll.setValue(parent.currentZ)
        parent.zpos.setText(str(parent.currentZ))
    else:
        parent.currentZ = 0


def _load_seg(parent, filename=None, image=None, image_file=None, load_3D=False):
    """ load *_seg.npy with filename; if None, open QFileDialog """
    if filename is None:
        name = QFileDialog.getOpenFileName(parent, "Load labelled data", filter="*.npy")
        filename = name[0]
    try:
        dat = np.load(filename, allow_pickle=True).item()
        # check if there are keys in filename
        dat["outlines"]
        parent.loaded = True
    except:
        parent.loaded = False
        print("ERROR: not NPY")
        return

    parent.reset()
    if image is None:
        # try the path stored in the seg file, then the seg file's own folder
        found_image = False
        if "filename" in dat:
            parent.filename = dat["filename"]
            if os.path.isfile(parent.filename):
                parent.filename = dat["filename"]
                found_image = True
            else:
                imgname = os.path.split(parent.filename)[1]
                root = os.path.split(filename)[0]
                parent.filename = root + "/" + imgname
                if os.path.isfile(parent.filename):
                    found_image = True
        if found_image:
            try:
                print(parent.filename)
                image = (imread_2D(parent.filename) if not load_3D else
                         imread_3D(parent.filename))
            except:
                parent.loaded = False
                found_image = False
                print("ERROR: cannot find image file, loading from npy")
        if not found_image:
            # strip the "_seg.npy" suffix (8 chars) to recover the image name
            parent.filename = filename[:-8]
            print(parent.filename)
            if "img" in dat:
                image = dat["img"]
            else:
                print("ERROR: no image file found and no image in npy")
                return
    else:
        parent.filename = image_file

    parent.restore = None
    parent.ratio = 1.

    if "normalize_params" in dat:
        parent.set_normalize_params(dat["normalize_params"])

    _initialize_images(parent, image, load_3D=load_3D)
    print(parent.stack.shape)

    if "outlines" in dat:
        if isinstance(dat["outlines"], list):
            # old way of saving files: outlines stored as a reversed list of point sets
            dat["outlines"] = dat["outlines"][::-1]
            for k, outline in enumerate(dat["outlines"]):
                if "colors" in dat:
                    color = dat["colors"][k]
                else:
                    col_rand = np.random.randint(1000)
                    color = parent.colormap[col_rand, :3]
                median = parent.add_mask(points=outline, color=color)
                if median is not None:
                    parent.cellcolors = np.append(parent.cellcolors,
                                                  color[np.newaxis, :], axis=0)
                    parent.ncells += 1
        else:
            # masks saved with -1 background are shifted to 0-based
            if dat["masks"].min() == -1:
                dat["masks"] += 1
                dat["outlines"] += 1
            parent.ncells.set(dat["masks"].max())
            if "colors" in dat and len(dat["colors"]) == dat["masks"].max():
                colors = dat["colors"]
            else:
                colors = parent.colormap[:parent.ncells.get(), :3]

            _masks_to_gui(parent, dat["masks"], outlines=dat["outlines"], colors=colors)

        parent.draw_layer()

        if "manual_changes" in dat:
            parent.track_changes = dat["manual_changes"]
            print("GUI_INFO: loaded in previous changes")
        if "zdraw" in dat:
            parent.zdraw = dat["zdraw"]
        else:
            parent.zdraw = [None for n in range(parent.ncells.get())]
        parent.loaded = True
    else:
        parent.clear_all()

    parent.ismanual = np.zeros(parent.ncells.get(), bool)
    if "ismanual" in dat:
        # NOTE(review): compares len() against parent.ncells (an object with
        # .get()/.set() elsewhere) — relies on its __eq__ with ints; confirm.
        if len(dat["ismanual"]) == parent.ncells:
            parent.ismanual = dat["ismanual"]

    if "current_channel" in dat:
        parent.color = (dat["current_channel"] + 2) % 5
        parent.RGBDropDown.setCurrentIndex(parent.color)

    if "flows" in dat:
        parent.flows = dat["flows"]
        try:
            # resize flows to the mask resolution if they were saved smaller
            if parent.flows[0].shape[-3] != dat["masks"].shape[-2]:
                Ly, Lx = dat["masks"].shape[-2:]
                for i in range(len(parent.flows)):
                    parent.flows[i] = cv2.resize(
                        parent.flows[i].squeeze(), (Lx, Ly),
                        interpolation=cv2.INTER_NEAREST)[np.newaxis, ...]
            if parent.NZ == 1:
                parent.recompute_masks = True
            else:
                parent.recompute_masks = False

        except:
            # flows in an unexpected (older) layout — best-effort fallback
            try:
                if len(parent.flows[0]) > 0:
                    parent.flows = parent.flows[0]
            except:
                parent.flows = [[], [], [], [], [[]]]
            parent.recompute_masks = False

    parent.enable_buttons()
    parent.update_layer()
    del dat
    gc.collect()


def _load_masks(parent, filename=None):
    """ load zeros-based masks (0=no cell, 1=cell 1, ...) """
    if filename is None:
        name = QFileDialog.getOpenFileName(parent, "Load masks (PNG or TIFF)")
        filename = name[0]
    # NOTE(review): f-string has no placeholder — literally prints "(unknown)";
    # presumably meant {filename}.
    print(f"GUI_INFO: loading masks: (unknown)")
    masks = imread(filename)
    outlines = None
    if masks.ndim > 3:
        # Z x nchannels x Ly x Lx
        if masks.shape[-1] > 5:
            # channels: [masks, outlines, flows...]
            parent.flows = list(np.transpose(masks[:, :, :, 2:], (3, 0, 1, 2)))
            outlines = masks[..., 1]
            masks = masks[..., 0]
        else:
            parent.flows = list(np.transpose(masks[:, :, :, 1:], (3, 0, 1, 2)))
            masks = masks[..., 0]
    elif masks.ndim == 3:
        if masks.shape[-1] < 5:
            masks = masks[np.newaxis, :, :, 0]
    elif masks.ndim < 3:
        masks = masks[np.newaxis, :, :]
    # masks should be Z x Ly x Lx
    if masks.shape[0] != parent.NZ:
        print("ERROR: masks are not same depth (number of planes) as image stack")
        return

    _masks_to_gui(parent, masks, outlines)
    if parent.ncells > 0:
        parent.draw_layer()
        parent.toggle_mask_ops()
    del masks
    gc.collect()
    parent.update_layer()
    parent.update_plot()


def _masks_to_gui(parent, masks, outlines=None, colors=None):
    """ masks loaded into GUI """
    # get unique values; renumber so mask IDs are consecutive 1..N
    shape = masks.shape
    if len(fastremap.unique(masks)) != masks.max() + 1:
        print("GUI_INFO: renumbering masks")
        fastremap.renumber(masks, in_place=True)
        outlines = None
        masks = masks.reshape(shape)
    if masks.ndim == 2:
        outlines = None
    masks = masks.astype(np.uint16) if masks.max() < 2**16 - 1 else masks.astype(
        np.uint32)
    if parent.restore and "upsample" in parent.restore:
        # keep both the upsampled masks and a version at original resolution
        parent.cellpix_resize = masks.copy()
        parent.cellpix = parent.cellpix_resize.copy()
        parent.cellpix_orig = cv2.resize(
            masks.squeeze(), (parent.Lx0, parent.Ly0),
            interpolation=cv2.INTER_NEAREST)[np.newaxis, :, :]
        parent.resize = True
    else:
        parent.cellpix = masks
    if parent.cellpix.ndim == 2:
        parent.cellpix = parent.cellpix[np.newaxis, :, :]
    if parent.restore and "upsample" in parent.restore:
        if parent.cellpix_resize.ndim == 2:
            parent.cellpix_resize = parent.cellpix_resize[np.newaxis, :, :]
        if parent.cellpix_orig.ndim == 2:
            parent.cellpix_orig = parent.cellpix_orig[np.newaxis, :, :]

    print(f"GUI_INFO: {masks.max()} masks found")

    # get outlines (compute per-plane if none were provided)
    if outlines is None:  # parent.outlinesOn
        parent.outpix = np.zeros_like(parent.cellpix)
        if parent.restore and "upsample" in parent.restore:
            parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
        for z in range(parent.NZ):
            outlines = masks_to_outlines(parent.cellpix[z])
            parent.outpix[z] = outlines * parent.cellpix[z]
            if parent.restore and "upsample" in parent.restore:
                outlines = masks_to_outlines(parent.cellpix_orig[z])
                parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
            if z % 50 == 0 and parent.NZ > 1:
                print("GUI_INFO: plane %d outlines processed" % z)
        if parent.restore and "upsample" in parent.restore:
            parent.outpix_resize = parent.outpix.copy()
    else:
        parent.outpix = outlines
        if parent.restore and "upsample" in parent.restore:
            parent.outpix_resize = parent.outpix.copy()
            parent.outpix_orig = np.zeros_like(parent.cellpix_orig)
            for z in range(parent.NZ):
                outlines = masks_to_outlines(parent.cellpix_orig[z])
                parent.outpix_orig[z] = outlines * parent.cellpix_orig[z]
                if z % 50 == 0 and parent.NZ > 1:
                    print("GUI_INFO: plane %d outlines processed" % z)

    if parent.outpix.ndim == 2:
        parent.outpix = parent.outpix[np.newaxis, :, :]
    if parent.restore and "upsample" in parent.restore:
        if parent.outpix_resize.ndim == 2:
            parent.outpix_resize = parent.outpix_resize[np.newaxis, :, :]
        if parent.outpix_orig.ndim == 2:
            parent.outpix_orig = parent.outpix_orig[np.newaxis, :, :]

    parent.ncells.set(parent.cellpix.max())
    colors = parent.colormap[:parent.ncells.get(), :3] if colors is None else colors
    print("GUI_INFO: creating cellcolors and drawing masks")
    # row 0 (white) is the background color
    parent.cellcolors = np.concatenate((np.array([[255, 255, 255]]), colors),
                                       axis=0).astype(np.uint8)
    if parent.ncells > 0:
        parent.draw_layer()
        parent.toggle_mask_ops()
    parent.ismanual = np.zeros(parent.ncells.get(), bool)
    parent.zdraw = list(-1 * np.ones(parent.ncells.get(), np.int16))

    if hasattr(parent, "stack_filtered"):
        parent.ViewDropDown.setCurrentIndex(parent.ViewDropDown.count() - 1)
        print("set denoised/filtered view")
    else:
        parent.ViewDropDown.setCurrentIndex(0)


def _save_png(parent):
    """ save masks to png or tiff (if 3D) """
    filename = parent.filename
    base = os.path.splitext(filename)[0]
    if parent.NZ == 1:
        # PNG is 16-bit, so fall back to tif above 65534 masks
        if parent.cellpix[0].max() > 65534:
            print("GUI_INFO: saving 2D masks to tif (too many masks for PNG)")
            imsave(base + "_cp_masks.tif", parent.cellpix[0])
        else:
            print("GUI_INFO: saving 2D masks to png")
            imsave(base + "_cp_masks.png", parent.cellpix[0].astype(np.uint16))
    else:
        print("GUI_INFO: saving 3D masks to tiff")
        imsave(base + "_cp_masks.tif", parent.cellpix)


def _save_flows(parent):
    """ save flows and cellprob to tiff """
    filename = parent.filename
    base = os.path.splitext(filename)[0]
    print("GUI_INFO: saving flows and cellprob to tiff")
    # assumes flows[1] is cellprob, flows[0] the RGB flow image, flows[2] raw flows
    # — TODO confirm against the layout set in _load_seg
    if len(parent.flows) > 0:
        imsave(base + "_cp_cellprob.tif", parent.flows[1])
        for i in range(3):
            imsave(base + f"_cp_flows_{i}.tif", parent.flows[0][..., i])
        if len(parent.flows) > 2:
            imsave(base + "_cp_flows.tif", parent.flows[2])
        print("GUI_INFO: saved flows and cellprob")
    else:
        print("ERROR: no flows or cellprob found")


def _save_rois(parent):
    """ save masks as rois in .zip file for ImageJ """
    filename = parent.filename
    if parent.NZ == 1:
        print(
            f"GUI_INFO: saving {parent.cellpix[0].max()} ImageJ ROIs to .zip archive.")
        save_rois(parent.cellpix[0], parent.filename)
    else:
        print("ERROR: cannot save 3D outlines")


def _save_outlines(parent):
    """Save 2D outlines as an ImageJ-compatible text file."""
    filename = parent.filename
    base = os.path.splitext(filename)[0]
    if parent.NZ == 1:
        print(
            "GUI_INFO: saving 2D outlines to text file, see docs for info to load into ImageJ"
        )
        outlines = outlines_list(parent.cellpix[0])
        outlines_to_text(base, outlines)
    else:
        print("ERROR: cannot save 3D outlines")


def _save_sets_with_check(parent):
    """ Save masks and update *_seg.npy file. Use this function when saving should be optional
    based on the disableAutosave checkbox. Otherwise, use _save_sets """
    if not parent.disableAutosave.isChecked():
        _save_sets(parent)


def _save_sets(parent):
    """ save masks to *_seg.npy. This function should be used when saving
    is forced, e.g. when clicking the save button. Otherwise, use _save_sets_with_check
    """
    filename = parent.filename
    base = os.path.splitext(filename)[0]
    flow_threshold = parent.segmentation_settings.flow_threshold
    cellprob_threshold = parent.segmentation_settings.cellprob_threshold

    if parent.NZ > 1:
        # 3D: save full stacks plus z-draw info
        dat = {
            "outlines": parent.outpix,
            "colors": parent.cellcolors[1:],
            "masks": parent.cellpix,
            "current_channel": (parent.color - 2) % 5,
            "filename": parent.filename,
            "flows": parent.flows,
            "zdraw": parent.zdraw,
            "model_path": parent.current_model_path
                          if hasattr(parent, "current_model_path") else 0,
            "flow_threshold": flow_threshold,
            "cellprob_threshold": cellprob_threshold,
            "normalize_params": parent.get_normalize_params(),
            "restore": parent.restore,
            "ratio": parent.ratio,
            "diameter": parent.segmentation_settings.diameter
        }
        if parent.restore is not None:
            dat["img_restore"] = parent.stack_filtered
    else:
        # 2D: squeeze planes; with "upsample" restore, save the resized arrays
        dat = {
            "outlines": parent.outpix.squeeze() if parent.restore is None or
                        not "upsample" in parent.restore else parent.outpix_resize.squeeze(),
            "colors": parent.cellcolors[1:],
            "masks": parent.cellpix.squeeze() if parent.restore is None or
                     not "upsample" in parent.restore else parent.cellpix_resize.squeeze(),
            "filename": parent.filename,
            "flows": parent.flows,
            "ismanual": parent.ismanual,
            "manual_changes": parent.track_changes,
            "model_path": parent.current_model_path
                          if hasattr(parent, "current_model_path") else 0,
            "flow_threshold": flow_threshold,
            "cellprob_threshold": cellprob_threshold,
            "normalize_params": parent.get_normalize_params(),
            "restore": parent.restore,
            "ratio": parent.ratio,
            "diameter": parent.segmentation_settings.diameter
        }
        if parent.restore is not None:
            dat["img_restore"] = parent.stack_filtered
    try:
        np.save(base + "_seg.npy", dat)
        print("GUI_INFO: %d ROIs saved to %s" % (parent.ncells.get(), base + "_seg.npy"))
    except Exception as e:
        print(f"ERROR: {e}")
    del dat


import os, argparse
import numpy as np
from cellpose import io, transforms


def main():
    """CLI: slice XYZ volumes into random 2D crops (YX/ZY/ZX planes) saved to
    <dir>/train/ for training."""
    parser = argparse.ArgumentParser(description='Make slices of XYZ image data for training. Assumes image is ZXYC unless specified otherwise using --channel_axis and --z_axis')

    input_img_args = parser.add_argument_group("input image arguments")
    input_img_args.add_argument('--dir', default=[], type=str,
                                help='folder containing data to run or train on.')
    input_img_args.add_argument(
        '--image_path', default=[], type=str, help=
        'if given and --dir not given, run on single image instead of folder (cannot train with this option)'
    )
    input_img_args.add_argument(
        '--look_one_level_down', action='store_true',
        help='run processing on all subdirectories of current folder')
    input_img_args.add_argument('--img_filter', default=[], type=str,
                                help='end string for images to run on')
    input_img_args.add_argument(
        '--channel_axis', default=-1, type=int,
        help='axis of image which corresponds to image channels')
    input_img_args.add_argument('--z_axis', default=0, type=int,
                                help='axis of image which corresponds to Z dimension')
    input_img_args.add_argument(
        '--chan', default=0, type=int, help=
        'Deprecated')
    input_img_args.add_argument(
        '--chan2', default=0, type=int, help=
        'Deprecated'
    )
    input_img_args.add_argument('--invert', action='store_true',
                                help='invert grayscale channel')
    input_img_args.add_argument(
        '--all_channels', action='store_true', help=
        'deprecated')
    input_img_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
                                help="anisotropy of volume in 3D")

    # algorithm settings
    algorithm_args = parser.add_argument_group("algorithm arguments")
    algorithm_args.add_argument('--sharpen_radius', required=False, default=0.0,
                                type=float, help='high-pass filtering radius. Default: %(default)s')
    algorithm_args.add_argument('--tile_norm', required=False, default=0, type=int,
                                help='tile normalization block size. Default: %(default)s')
    algorithm_args.add_argument('--nimg_per_tif', required=False, default=10, type=int,
                                help='number of crops in XY to save per tiff. Default: %(default)s')
    algorithm_args.add_argument('--crop_size', required=False, default=512, type=int,
                                help='size of random crop to save. Default: %(default)s')

    args = parser.parse_args()

    # find images
    if len(args.img_filter) > 0:
        imf = args.img_filter
    else:
        imf = None

    if len(args.dir) > 0:
        image_names = io.get_image_files(args.dir, "_masks", imf=imf,
                                         look_one_level_down=args.look_one_level_down)
        dirname = args.dir
    else:
        if os.path.exists(args.image_path):
            image_names = [args.image_path]
            dirname = os.path.split(args.image_path)[0]
        else:
            raise ValueError(f"ERROR: no file found at {args.image_path}")

    # fixed seed so crop selection is reproducible
    np.random.seed(0)
    nimg_per_tif = args.nimg_per_tif
    crop_size = args.crop_size
    os.makedirs(os.path.join(dirname, 'train/'), exist_ok=True)
    # axis permutations producing YX, ZY and ZX plane stacks
    pm = [(0, 1, 2, 3), (2, 0, 1, 3), (1, 0, 2, 3)]
    npm = ["YX", "ZY", "ZX"]
    for name in image_names:
        name0 = os.path.splitext(os.path.split(name)[-1])[0]
        img0 = io.imread_3D(name)
        try:
            img0 = transforms.convert_image(img0, channel_axis=args.channel_axis,
                                            z_axis=args.z_axis, do_3D=True)
        except ValueError:
            # NOTE(review): on failure this only prints and continues with the
            # unconverted image — confirm whether it should skip the file instead.
            print('Error converting image. Did you provide the correct --channel_axis and --z_axis ?')

        for p in range(3):
            img = img0.transpose(pm[p]).copy()
            print(npm[p], img[0].shape)
            Ly, Lx = img.shape[1:3]
            # pick random planes from this orientation
            imgs = img[np.random.permutation(img.shape[0])[:args.nimg_per_tif]]
            if args.anisotropy > 1.0 and p > 0:
                # stretch Z-containing planes to compensate for slice spacing
                imgs = transforms.resize_image(imgs, Ly=int(args.anisotropy * Ly), Lx=Lx)
            for k, img in enumerate(imgs):
                if args.tile_norm:
                    img = transforms.normalize99_tile(img, blocksize=args.tile_norm)
                if args.sharpen_radius:
                    img = transforms.smooth_sharpen_img(img,
                                                        sharpen_radius=args.sharpen_radius)
                # random crop (or origin crop if the plane is smaller than crop_size)
                ly = 0 if Ly - crop_size <= 0 else np.random.randint(0, Ly - crop_size)
                lx = 0 if Lx - crop_size <= 0 else np.random.randint(0, Lx - crop_size)
                io.imsave(os.path.join(dirname, f'train/{name0}_{npm[p]}_{k}.tif'),
                          img[ly:ly + args.crop_size, lx:lx + args.crop_size].squeeze())


if __name__ == '__main__':
    main()
from . import io


def mainmenu(parent):
    """Build the File menu: image/mask loading plus every save action."""
    menubar = parent.menuBar()
    menu = menubar.addMenu("&File")
    # load processed data
    load_img = QAction("&Load image (*.tif, *.png, *.jpg)", parent)
    load_img.setShortcut("Ctrl+L")
    load_img.triggered.connect(lambda: io._load_image(parent))
    menu.addAction(load_img)

    # checkable toggles stored on parent so other code can query them
    parent.autoloadMasks = QAction("Autoload masks from _masks.tif file", parent,
                                   checkable=True)
    parent.autoloadMasks.setChecked(False)
    menu.addAction(parent.autoloadMasks)

    parent.disableAutosave = QAction("Disable autosave _seg.npy file", parent,
                                     checkable=True)
    parent.disableAutosave.setChecked(False)
    menu.addAction(parent.disableAutosave)

    parent.loadMasks = QAction("Load &masks (*.tif, *.png, *.jpg)", parent)
    parent.loadMasks.setShortcut("Ctrl+M")
    parent.loadMasks.triggered.connect(lambda: io._load_masks(parent))
    menu.addAction(parent.loadMasks)
    parent.loadMasks.setEnabled(False)

    load_manual = QAction("Load &processed/labelled image (*_seg.npy)", parent)
    load_manual.setShortcut("Ctrl+P")
    load_manual.triggered.connect(lambda: io._load_seg(parent))
    menu.addAction(load_manual)

    # save actions start disabled; the GUI enables them once data exists
    parent.saveSet = QAction("&Save masks and image (as *_seg.npy)", parent)
    parent.saveSet.setShortcut("Ctrl+S")
    parent.saveSet.triggered.connect(lambda: io._save_sets(parent))
    menu.addAction(parent.saveSet)
    parent.saveSet.setEnabled(False)

    parent.savePNG = QAction("Save masks as P&NG/tif", parent)
    parent.savePNG.setShortcut("Ctrl+N")
    parent.savePNG.triggered.connect(lambda: io._save_png(parent))
    menu.addAction(parent.savePNG)
    parent.savePNG.setEnabled(False)

    parent.saveOutlines = QAction("Save &Outlines as text for imageJ", parent)
    parent.saveOutlines.setShortcut("Ctrl+O")
    parent.saveOutlines.triggered.connect(lambda: io._save_outlines(parent))
    menu.addAction(parent.saveOutlines)
    parent.saveOutlines.setEnabled(False)

    parent.saveROIs = QAction("Save outlines as .zip archive of &ROI files for ImageJ",
                              parent)
    parent.saveROIs.setShortcut("Ctrl+R")
    parent.saveROIs.triggered.connect(lambda: io._save_rois(parent))
    menu.addAction(parent.saveROIs)
    parent.saveROIs.setEnabled(False)

    parent.saveFlows = QAction("Save &Flows and cellprob as tif", parent)
    parent.saveFlows.setShortcut("Ctrl+F")
    parent.saveFlows.triggered.connect(lambda: io._save_flows(parent))
    menu.addAction(parent.saveFlows)
    parent.saveFlows.setEnabled(False)


def editmenu(parent):
    """Build the Edit menu: undo/redo, clearing and removing masks."""
    menubar = parent.menuBar()
    menu = menubar.addMenu("&Edit")

    parent.undo = QAction("Undo previous mask/trace", parent)
    parent.undo.setShortcut("Ctrl+Z")
    parent.undo.triggered.connect(parent.undo_action)
    parent.undo.setEnabled(False)
    menu.addAction(parent.undo)

    parent.redo = QAction("Undo remove mask", parent)
    parent.redo.setShortcut("Ctrl+Y")
    parent.redo.triggered.connect(parent.undo_remove_action)
    parent.redo.setEnabled(False)
    menu.addAction(parent.redo)

    parent.ClearButton = QAction("Clear all masks", parent)
    parent.ClearButton.setShortcut("Ctrl+0")
    parent.ClearButton.triggered.connect(parent.clear_all)
    parent.ClearButton.setEnabled(False)
    menu.addAction(parent.ClearButton)

    parent.remcell = QAction("Remove selected cell (Ctrl+CLICK)", parent)
    parent.remcell.setShortcut("Ctrl+Click")
    parent.remcell.triggered.connect(parent.remove_action)
    parent.remcell.setEnabled(False)
    menu.addAction(parent.remcell)

    # informational entry only — the merge itself is handled by the click handler
    parent.mergecell = QAction("FYI: Merge cells by Alt+Click", parent)
    parent.mergecell.setEnabled(False)
    menu.addAction(parent.mergecell)


def modelmenu(parent):
    """Build the Models menu: custom model management and training."""
    menubar = parent.menuBar()
    io._init_model_list(parent)
    menu = menubar.addMenu("&Models")

    parent.addmodel = QAction("Add custom torch model to GUI", parent)
    #parent.addmodel.setShortcut("Ctrl+A")
    parent.addmodel.triggered.connect(parent.add_model)
    parent.addmodel.setEnabled(True)
    menu.addAction(parent.addmodel)

    parent.removemodel = QAction("Remove selected custom model from GUI", parent)
    #parent.removemodel.setShortcut("Ctrl+R")
    parent.removemodel.triggered.connect(parent.remove_model)
    parent.removemodel.setEnabled(True)
    menu.addAction(parent.removemodel)

    parent.newmodel = QAction("&Train new model with image+masks in folder", parent)
    parent.newmodel.setShortcut("Ctrl+T")
    parent.newmodel.triggered.connect(parent.new_model)
    parent.newmodel.setEnabled(False)
    menu.addAction(parent.newmodel)

    train_help = QAction("Training instructions", parent)
    train_help.triggered.connect(parent.train_help_window)
    menu.addAction(train_help)
model_menu.addAction(parent.addmodel) + + parent.removemodel = QAction("Remove selected custom model from GUI", parent) + #parent.removemodel.setShortcut("Ctrl+R") + parent.removemodel.triggered.connect(parent.remove_model) + parent.removemodel.setEnabled(True) + model_menu.addAction(parent.removemodel) + + parent.newmodel = QAction("&Train new model with image+masks in folder", parent) + parent.newmodel.setShortcut("Ctrl+T") + parent.newmodel.triggered.connect(parent.new_model) + parent.newmodel.setEnabled(False) + model_menu.addAction(parent.newmodel) + + openTrainHelp = QAction("Training instructions", parent) + openTrainHelp.triggered.connect(parent.train_help_window) + model_menu.addAction(openTrainHelp) + + +def helpmenu(parent): + main_menu = parent.menuBar() + help_menu = main_menu.addMenu("&Help") + + openHelp = QAction("&Help with GUI", parent) + openHelp.setShortcut("Ctrl+H") + openHelp.triggered.connect(parent.help_window) + help_menu.addAction(openHelp) + + openGUI = QAction("&GUI layout", parent) + openGUI.setShortcut("Ctrl+G") + openGUI.triggered.connect(parent.gui_window) + help_menu.addAction(openGUI) + + openTrainHelp = QAction("Training instructions", parent) + openTrainHelp.triggered.connect(parent.train_help_window) + help_menu.addAction(openTrainHelp) diff --git a/models/seg_post_model/cellpose/io.py b/models/seg_post_model/cellpose/io.py new file mode 100644 index 0000000000000000000000000000000000000000..48184e6540a942efb45f20740ed77c9ddbf0c1d1 --- /dev/null +++ b/models/seg_post_model/cellpose/io.py @@ -0,0 +1,816 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu. 
+""" +import os, warnings, glob, shutil +from natsort import natsorted +import numpy as np +import cv2 +import tifffile +import logging, pathlib, sys +from tqdm import tqdm +from pathlib import Path +import re +from .version import version_str +from roifile import ImagejRoi, roiwrite + +try: + from qtpy import QtGui, QtCore, Qt, QtWidgets + from qtpy.QtWidgets import QMessageBox + GUI = True +except: + GUI = False + +try: + import matplotlib.pyplot as plt + MATPLOTLIB = True +except: + MATPLOTLIB = False + +try: + import nd2 + ND2 = True +except: + ND2 = False + +try: + import nrrd + NRRD = True +except: + NRRD = False + +try: + from google.cloud import storage + SERVER_UPLOAD = True +except: + SERVER_UPLOAD = False + +io_logger = logging.getLogger(__name__) + +def logger_setup(cp_path=".cellpose", logfile_name="run.log", stdout_file_replacement=None): + cp_dir = pathlib.Path.home().joinpath(cp_path) + cp_dir.mkdir(exist_ok=True) + log_file = cp_dir.joinpath(logfile_name) + try: + log_file.unlink() + except: + print('creating new log file') + handlers = [logging.FileHandler(log_file),] + if stdout_file_replacement is not None: + handlers.append(logging.FileHandler(stdout_file_replacement)) + else: + handlers.append(logging.StreamHandler(sys.stdout)) + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=handlers, + force=True + ) + logger = logging.getLogger(__name__) + logger.info(f"WRITING LOG OUTPUT TO {log_file}") + logger.info(version_str) + + return logger, log_file + + +from . 
# helper function to check for a path; if it doesn't exist, make it
def check_dir(path):
    """Create directory `path` if it does not already exist (single level)."""
    if not os.path.isdir(path):
        os.mkdir(path)


def outlines_to_text(base, outlines):
    """Write outlines to `base`_cp_outlines.txt, one comma-joined x,y list per line."""
    with open(base + "_cp_outlines.txt", "w") as f:
        for o in outlines:
            xy = list(o.flatten())
            xy_str = ",".join(map(str, xy))
            f.write(xy_str)
            f.write("\n")


def load_dax(filename):
    """Load a .dax movie using the metadata in its sibling .inf file.

    Returns:
        numpy.ndarray or None: (frames, height, width) uint16 array, or None
        if the .inf file is missing.
    """
    ### modified from ZhuangLab github:
    ### https://github.com/ZhuangLab/storm-analysis/blob/71ae493cbd17ddb97938d0ae2032d97a0eaa76b2/storm_analysis/sa_library/datareader.py#L156

    inf_filename = os.path.splitext(filename)[0] + ".inf"
    if not os.path.exists(inf_filename):
        # FIX: message contained a garbled "(unknown)" placeholder instead of the filename
        io_logger.critical(
            f"ERROR: no inf file found for dax file {filename}, cannot load dax without it"
        )
        return None

    ### get metadata
    # FIX: number_frames and bigendian could be referenced unbound when the
    # .inf file lacked the corresponding entries; give them safe defaults.
    image_height, image_width = None, None
    number_frames = None
    bigendian = 0
    # extract the movie information from the associated inf file
    size_re = re.compile(r"frame dimensions = ([\d]+) x ([\d]+)")
    length_re = re.compile(r"number of frames = ([\d]+)")
    endian_re = re.compile(r" (big|little) endian")

    with open(inf_filename, "r") as inf_file:
        lines = inf_file.read().split("\n")
    for line in lines:
        m = size_re.match(line)
        if m:
            image_height = int(m.group(2))
            image_width = int(m.group(1))
        m = length_re.match(line)
        if m:
            number_frames = int(m.group(1))
        m = endian_re.search(line)
        if m:
            bigendian = 1 if m.group(1) == "big" else 0

    # set defaults, warn the user that they couldn't be determined from the inf file.
    if not image_height:
        io_logger.warning("could not determine dax image size, assuming 256x256")
        image_height = 256
        image_width = 256
    if number_frames is None:
        # infer the frame count from the file size (2 bytes per uint16 pixel)
        number_frames = os.path.getsize(filename) // (2 * image_height * image_width)
        io_logger.warning(
            f"could not determine dax frame count, inferring {number_frames} from file size")

    ### load image
    img = np.memmap(filename, dtype="uint16",
                    shape=(number_frames, image_height, image_width))
    if bigendian:
        img = img.byteswap()
    img = np.array(img)

    return img


def imread(filename):
    """
    Read in an image file with tif or image file type supported by cv2.

    Args:
        filename (str): The path to the image file.

    Returns:
        numpy.ndarray: The image data as a NumPy array.

    Raises an error if the image file format is not supported.

    Examples:
        >>> img = imread("image.tif")
    """
    # ensure that extension check is not case sensitive
    ext = os.path.splitext(filename)[-1].lower()
    if ext == ".tif" or ext == ".tiff" or ext == ".flex":
        with tifffile.TiffFile(filename) as tif:
            ltif = len(tif.pages)
            try:
                full_shape = tif.shaped_metadata[0]["shape"]
            except Exception:
                try:
                    page = tif.series[0][0]
                    full_shape = tif.series[0].shape
                except Exception:
                    # no shape metadata available; read everything in one call
                    ltif = 0
            if ltif < 10:
                img = tif.asarray()
            else:
                # large stacks: read page by page into a preallocated array
                page = tif.series[0][0]
                shape, dtype = page.shape, page.dtype
                ltif = int(np.prod(full_shape) / np.prod(shape))
                io_logger.info(f"reading tiff with {ltif} planes")
                img = np.zeros((ltif, *shape), dtype=dtype)
                for i, page in enumerate(tqdm(tif.series[0])):
                    img[i] = page.asarray()
                img = img.reshape(full_shape)
        return img
    elif ext == ".dax":
        img = load_dax(filename)
        return img
    elif ext == ".nd2":
        if not ND2:
            io_logger.critical("ERROR: need to 'pip install nd2' to load in .nd2 file")
            return None
        else:
            # FIX: original fell through and implicitly returned None even
            # when the nd2 package was installed
            img = nd2.imread(filename)
            return img
    elif ext == ".nrrd":
        if not NRRD:
            io_logger.critical(
                "ERROR: need to 'pip install pynrrd' to load in .nrrd file")
            return None
        else:
            img, metadata = nrrd.read(filename)
            if img.ndim == 3:
                # put the stack axis first
                img = img.transpose(2, 0, 1)
            return img
    elif ext != ".npy":
        try:
            img = cv2.imread(filename, -1)  #cv2.LOAD_IMAGE_ANYDEPTH)
            if img.ndim > 2:
                # cv2 loads BGR; reorder to RGB
                img = img[..., [2, 1, 0]]
            return img
        except Exception as e:
            io_logger.critical("ERROR: could not read file, %s" % e)
            return None
    else:
        try:
            dat = np.load(filename, allow_pickle=True).item()
            masks = dat["masks"]
            return masks
        except Exception as e:
            io_logger.critical("ERROR: could not read masks from file, %s" % e)
            return None


def imread_2D(img_file):
    """
    Read in a 2D image file and convert it to a 3-channel image. Attempts to do this for multi-channel and grayscale images.
    If the image has more than 3 channels, only the first 3 channels are kept.

    Args:
        img_file (str): The path to the image file.

    Returns:
        img_out (numpy.ndarray): The 3-channel image data as a NumPy array.
    """
    img = imread(img_file)
    return transforms.convert_image(img, do_3D=False)
def imread_3D(img_file):
    """
    Read in a 3D image file and convert it to have a channel axis last automatically. Attempts to do this for multi-channel and grayscale images.

    If multichannel image, the channel axis is assumed to be the smallest dimension, and the z axis is the next smallest dimension.
    Use `cellpose.io.imread()` to load the full image without selecting the z and channel axes.

    Args:
        img_file (str): The path to the image file.

    Returns:
        img_out (numpy.ndarray): The image data as a NumPy array.
    """
    img = imread(img_file)

    dimension_lengths = list(img.shape)

    # grayscale images:
    if img.ndim == 3:
        channel_axis = None
        # guess at z axis: smallest dimension
        z_axis = np.argmin(dimension_lengths)
    elif img.ndim == 4:
        # guess at channel axis: smallest dimension
        channel_axis = np.argmin(dimension_lengths)
        # guess at z axis: next smallest (mask out the channel axis first)
        dimension_lengths[channel_axis] = max(dimension_lengths)
        z_axis = np.argmin(dimension_lengths)
    else:
        raise ValueError(f'image shape error, 3D image must 3 or 4 dimensional. Number of dimensions: {img.ndim}')

    try:
        return transforms.convert_image(img, channel_axis=channel_axis, z_axis=z_axis, do_3D=True)
    except Exception as e:
        io_logger.critical("ERROR: could not read file, %s" % e)
        io_logger.critical("ERROR: Guessed z_axis: %s, channel_axis: %s" % (z_axis, channel_axis))
        return None


def remove_model(filename, delete=False):
    """Remove a model from the .cellpose custom model list.

    Args:
        filename (str): model filename (path components are stripped).
        delete (bool): also delete the model file from the models folder.
    """
    filename = os.path.split(filename)[-1]
    from . import models
    model_strings = models.get_user_models()
    # FIX: the original rewrote the list unchanged (the entry was never
    # removed) and the delete branch removed `fname`, a leftover loop
    # variable (NameError when the list was empty).
    if filename in model_strings:
        model_strings.remove(filename)
    if len(model_strings) > 0:
        with open(models.MODEL_LIST_PATH, "w") as textfile:
            for fname in model_strings:
                textfile.write(fname + "\n")
    else:
        # write empty file
        textfile = open(models.MODEL_LIST_PATH, "w")
        textfile.close()
    # FIX: message printed a garbled "(unknown)" placeholder
    print(f"{filename} removed from custom model list")
    if delete:
        os.remove(os.fspath(models.MODEL_DIR.joinpath(filename)))
        print("model deleted")


def add_model(filename):
    """ add model to .cellpose models folder to use with GUI or CLI """
    from . import models
    fname = os.path.split(filename)[-1]
    try:
        shutil.copyfile(filename, os.fspath(models.MODEL_DIR.joinpath(fname)))
    except shutil.SameFileError:
        # source and destination are the same file; nothing to copy
        pass
    # FIX: message printed a garbled "(unknown)" placeholder
    print(f"{fname} copied to models folder {os.fspath(models.MODEL_DIR)}")
    if fname not in models.get_user_models():
        with open(models.MODEL_LIST_PATH, "a") as textfile:
            textfile.write(fname + "\n")


def imsave(filename, arr):
    """
    Saves an image array to a file.

    Args:
        filename (str): The name of the file to save the image to.
        arr (numpy.ndarray): The image array to be saved.

    Returns:
        None
    """
    ext = os.path.splitext(filename)[-1].lower()
    if ext == ".tif" or ext == ".tiff":
        tifffile.imwrite(filename, data=arr, compression="zlib")
    else:
        if len(arr.shape) > 2:
            # cv2 writes BGR; reorder channels so the file ends up RGB
            arr = cv2.cvtColor(arr, cv2.COLOR_BGR2RGB)
        cv2.imwrite(filename, arr)


def get_image_files(folder, mask_filter, imf=None, look_one_level_down=False):
    """
    Finds all images in a folder and its subfolders (if specified) with the given file extensions.

    Args:
        folder (str): The path to the folder to search for images.
        mask_filter (str): The filter for mask files.
        imf (str, optional): The additional filter for image files. Defaults to None.
        look_one_level_down (bool, optional): Whether to search for images in subfolders. Defaults to False.

    Returns:
        list: A list of image file paths.

    Raises:
        ValueError: If no files are found in the specified folder.
        ValueError: If no images are found in the specified folder with the supported file extensions.
        ValueError: If no images are found in the specified folder without the mask or flow file endings.
    """
    mask_filters = ["_cp_output", "_flows", "_flows_0", "_flows_1",
                    "_flows_2", "_cellprob", "_masks", mask_filter]
    image_names = []
    if imf is None:
        imf = ""

    folders = []
    if look_one_level_down:
        folders = natsorted(glob.glob(os.path.join(folder, "*/")))
    folders.append(folder)
    exts = [".png", ".jpg", ".jpeg", ".tif", ".tiff", ".flex", ".dax", ".nd2", ".nrrd"]
    l0 = 0
    al = 0
    for folder in folders:
        all_files = glob.glob(folder + "/*")
        al += len(all_files)
        for ext in exts:
            # match both lower- and upper-case extensions
            image_names.extend(glob.glob(folder + f"/*{imf}{ext}"))
            image_names.extend(glob.glob(folder + f"/*{imf}{ext.upper()}"))
        l0 += len(image_names)

    # return error if no files found
    if al == 0:
        raise ValueError("ERROR: no files in --dir folder ")
    elif l0 == 0:
        raise ValueError(
            "ERROR: no images in --dir folder with extensions .png, .jpg, .jpeg, .tif, .tiff, .flex"
        )

    image_names = natsorted(image_names)
    imn = []
    for im in image_names:
        imfile = os.path.splitext(im)[0]
        # keep a file only if its stem does not end with any mask/flow suffix
        igood = all([(len(imfile) > len(mask_filter) and
                      imfile[-len(mask_filter):] != mask_filter) or
                     len(imfile) <= len(mask_filter) for mask_filter in mask_filters])
        if len(imf) > 0:
            igood &= imfile[-len(imf):] == imf
        if igood:
            imn.append(im)

    image_names = imn

    # remove duplicates
    image_names = [*set(image_names)]
    image_names = natsorted(image_names)

    if len(image_names) == 0:
        raise ValueError(
            "ERROR: no images in --dir folder without _masks or _flows or _cellprob ending")

    return image_names
def get_label_files(image_names, mask_filter, imf=None):
    """
    Get the label files corresponding to the given image names and mask filter.

    Args:
        image_names (list): List of image names.
        mask_filter (str): Mask filter to be applied.
        imf (str, optional): Image file extension. Defaults to None.

    Returns:
        tuple: A tuple containing the label file names and flow file names (if present).

    Raises:
        ValueError: If labels are not provided with the correct --mask_filter,
            or not provided for all images, and no flow files exist either.
    """
    nimg = len(image_names)
    label_names0 = [os.path.splitext(image_names[n])[0] for n in range(nimg)]

    # strip the image filter suffix (if any) to get the label basenames
    if imf is not None and len(imf) > 0:
        label_names = [label_names0[n][:-len(imf)] for n in range(nimg)]
    else:
        label_names = label_names0

    # check for flows
    if os.path.exists(label_names0[0] + "_flows.tif"):
        flow_names = [label_names0[n] + "_flows.tif" for n in range(nimg)]
    else:
        flow_names = [label_names[n] + "_flows.tif" for n in range(nimg)]
        if not all([os.path.exists(flow) for flow in flow_names]):
            io_logger.info(
                "not all flows are present, running flow generation for all images")
            flow_names = None

    # check for masks
    if mask_filter == "_seg.npy":
        label_names = [label_names[n] + mask_filter for n in range(nimg)]
        return label_names, None

    if os.path.exists(label_names[0] + mask_filter + ".tif"):
        label_names = [label_names[n] + mask_filter + ".tif" for n in range(nimg)]
    elif os.path.exists(label_names[0] + mask_filter + ".tiff"):
        label_names = [label_names[n] + mask_filter + ".tiff" for n in range(nimg)]
    elif os.path.exists(label_names[0] + mask_filter + ".png"):
        label_names = [label_names[n] + mask_filter + ".png" for n in range(nimg)]
    # TODO, allow _seg.npy
    #elif os.path.exists(label_names[0] + "_seg.npy"):
    #    io_logger.info("labels found as _seg.npy files, converting to tif")
    else:
        if not flow_names:
            raise ValueError("labels not provided with correct --mask_filter")
        else:
            label_names = None
    # FIX: the original iterated label_names here even when it had just been
    # set to None above (TypeError); guard the existence check.
    if label_names is not None and not all(
            [os.path.exists(label) for label in label_names]):
        if not flow_names:
            raise ValueError(
                "labels not provided for all images in train and/or test set")
        else:
            label_names = None

    return label_names, flow_names
def load_images_labels(tdir, mask_filter="_masks", image_filter=None,
                       look_one_level_down=False):
    """
    Loads images and corresponding labels from a directory.

    Args:
        tdir (str): The directory path.
        mask_filter (str, optional): The filter for mask files. Defaults to "_masks".
        image_filter (str, optional): The filter for image files. Defaults to None.
        look_one_level_down (bool, optional): Whether to look for files one level down. Defaults to False.

    Returns:
        tuple: A tuple containing a list of images, a list of labels, and a list of image names.
    """
    image_names = get_image_files(tdir, mask_filter, image_filter, look_one_level_down)
    nimg = len(image_names)

    # training data
    label_names, flow_names = get_label_files(image_names, mask_filter,
                                              imf=image_filter)

    images = []
    labels = []
    k = 0
    for n in range(nimg):
        # FIX: original checked flow_names[0] for every image (only the first
        # flow file was ever tested) and crashed when label_names was None.
        if ((label_names is not None and os.path.isfile(label_names[n])) or
                (flow_names is not None and os.path.isfile(flow_names[n]))):
            image = imread(image_names[n])
            label = None
            if label_names is not None:
                label = imread(label_names[n])
            if flow_names is not None:
                flow = imread(flow_names[n])
                # FIX: original referenced `label` unbound when labels were
                # absent but flows were present
                if label is not None and flow.shape[0] < 4:
                    label = np.concatenate((label[np.newaxis, :, :], flow), axis=0)
                else:
                    label = flow
            images.append(image)
            labels.append(label)
            k += 1
    io_logger.info(f"{k} / {nimg} images in {tdir} folder have labels")
    return images, labels, image_names


def load_train_test_data(train_dir, test_dir=None, image_filter=None,
                         mask_filter="_masks", look_one_level_down=False):
    """
    Loads training and testing data for a Cellpose model.

    Args:
        train_dir (str): The directory path containing the training data.
        test_dir (str, optional): The directory path containing the testing data. Defaults to None.
        image_filter (str, optional): The filter for selecting image files. Defaults to None.
        mask_filter (str, optional): The filter for selecting mask files. Defaults to "_masks".
        look_one_level_down (bool, optional): Whether to look for data in subdirectories of train_dir and test_dir. Defaults to False.

    Returns:
        images, labels, image_names, test_images, test_labels, test_image_names
    """
    images, labels, image_names = load_images_labels(train_dir, mask_filter,
                                                     image_filter, look_one_level_down)
    # testing data
    test_images, test_labels, test_image_names = None, None, None
    if test_dir is not None:
        test_images, test_labels, test_image_names = load_images_labels(
            test_dir, mask_filter, image_filter, look_one_level_down)

    return images, labels, image_names, test_images, test_labels, test_image_names


def masks_flows_to_seg(images, masks, flows, file_names, channels=None,
                       imgs_restore=None, restore_type=None, ratio=1.):
    """Save output of model eval to be loaded in GUI.

    Can be list output (run on multiple images) or single output (run on single image).

    Saved to file_names[k]+"_seg.npy".

    Args:
        images (list): Images input into cellpose.
        masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
        flows (list): Flows output from Cellpose.eval.
        file_names (list, str): Names of files of images.
        channels (list, int, optional): Channels used to run Cellpose. Defaults to None.
        imgs_restore (optional): restored images (if a restore model was run).
        restore_type (optional): name of the restore model.
        ratio (float): restore ratio saved alongside the restored image.

    Returns:
        None
    """
    if channels is None:
        channels = [0, 0]

    # list input: recurse once per image
    if isinstance(masks, list):
        if imgs_restore is None:
            imgs_restore = [None] * len(masks)
        if isinstance(file_names, str):
            file_names = [file_names] * len(masks)
        for k, [image, mask, flow, file_name, img_restore] in enumerate(
                zip(images, masks, flows, file_names, imgs_restore)):
            channels_img = channels
            if channels_img is not None and len(channels) > 2:
                channels_img = channels[k]
            masks_flows_to_seg(image, mask, flow, file_name,
                               channels=channels_img, imgs_restore=img_restore,
                               restore_type=restore_type, ratio=ratio)
        return

    if len(channels) == 1:
        channels = channels[0]

    # build the flow images the GUI expects; 3D flows get resized to mask size
    flowi = []
    if flows[0].ndim == 3:
        Ly, Lx = masks.shape[-2:]
        flowi.append(
            cv2.resize(flows[0], (Lx, Ly),
                       interpolation=cv2.INTER_NEAREST)[np.newaxis, ...])
    else:
        flowi.append(flows[0])

    if flows[0].ndim == 3:
        cellprob = (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(
            np.uint8)
        cellprob = cv2.resize(cellprob, (Lx, Ly), interpolation=cv2.INTER_NEAREST)
        flowi.append(cellprob[np.newaxis, ...])
        flowi.append(np.zeros(flows[0].shape, dtype=np.uint8))
        flowi[-1] = flowi[-1][np.newaxis, ...]
    else:
        flowi.append(
            (np.clip(transforms.normalize99(flows[2]), 0, 1) * 255).astype(np.uint8))
        flowi.append((flows[1][0] / 10 * 127 + 127).astype(np.uint8))
    if len(flows) > 2:
        if len(flows) > 3:
            flowi.append(flows[3])
        else:
            flowi.append([])
        flowi.append(np.concatenate((flows[1], flows[2][np.newaxis, ...]), axis=0))
    outlines = masks * utils.masks_to_outlines(masks)
    base = os.path.splitext(file_names)[0]

    dat = {
        "outlines":
            outlines.astype(np.uint16) if outlines.max() < 2**16 -
            1 else outlines.astype(np.uint32),
        # NOTE(review): the masks dtype is chosen from outlines.max(), not
        # masks.max() — preserved as-is; confirm this is intentional
        "masks":
            masks.astype(np.uint16) if outlines.max() < 2**16 -
            1 else masks.astype(np.uint32),
        "chan_choose":
            channels,
        "ismanual":
            np.zeros(masks.max(), bool),
        "filename":
            file_names,
        "flows":
            flowi,
        "diameter":
            np.nan
    }
    if restore_type is not None and imgs_restore is not None:
        dat["restore"] = restore_type
        dat["ratio"] = ratio
        dat["img_restore"] = imgs_restore

    np.save(base + "_seg.npy", dat)
def save_to_png(images, masks, flows, file_names):
    """ deprecated (runs io.save_masks with png=True)

    does not work for 3D images

    """
    save_masks(images, masks, flows, file_names, png=True)


def save_rois(masks, file_name, multiprocessing=None):
    """ save masks to .roi files in .zip archive for ImageJ/Fiji

    Args:
        masks (np.ndarray): masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels
        file_name (str): name to save the .zip file to

    Returns:
        None
    """
    all_outlines = utils.outlines_list(masks, multiprocessing=multiprocessing)
    kept = [outline for outline in all_outlines if len(outline) != 0]
    if len(all_outlines) != len(kept):
        print(f"empty outlines found, saving {len(kept)} ImageJ ROIs to .zip archive.")
    rois = [ImagejRoi.frompoints(outline) for outline in kept]
    zip_path = os.path.splitext(file_name)[0] + '_rois.zip'

    # Delete file if it exists; the roifile lib appends to existing zip files.
    # If the user removed a mask it will still be in the zip file
    if os.path.exists(zip_path):
        os.remove(zip_path)

    roiwrite(zip_path, rois)
+ # If the user removed a mask it will still be in the zip file + if os.path.exists(file_name): + os.remove(file_name) + + roiwrite(file_name, rois) + + +def save_masks(images, masks, flows, file_names, png=True, tif=False, channels=[0, 0], + suffix="_cp_masks", save_flows=False, save_outlines=False, dir_above=False, + in_folders=False, savedir=None, save_txt=False, save_mpl=False): + """ Save masks + nicely plotted segmentation image to png and/or tiff. + + Can save masks, flows to different directories, if in_folders is True. + + If png, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.png". + + If tif, masks[k] for images[k] are saved to file_names[k]+"_cp_masks.tif". + + If png and matplotlib installed, full segmentation figure is saved to file_names[k]+"_cp.png". + + Only tif option works for 3D data, and only tif option works for empty masks. + + Args: + images (list): Images input into cellpose. + masks (list): Masks output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels. + flows (list): Flows output from Cellpose.eval. + file_names (list, str): Names of files of images. + png (bool, optional): Save masks to PNG. Defaults to True. + tif (bool, optional): Save masks to TIF. Defaults to False. + channels (list, int, optional): Channels used to run Cellpose. Defaults to [0,0]. + suffix (str, optional): Add name to saved masks. Defaults to "_cp_masks". + save_flows (bool, optional): Save flows output from Cellpose.eval. Defaults to False. + save_outlines (bool, optional): Save outlines of masks. Defaults to False. + dir_above (bool, optional): Save masks/flows in directory above. Defaults to False. + in_folders (bool, optional): Save masks/flows in separate folders. Defaults to False. + savedir (str, optional): Absolute path where images will be saved. If None, saves to image directory. Defaults to None. + save_txt (bool, optional): Save masks as list of outlines for ImageJ. Defaults to False. 
+ save_mpl (bool, optional): If True, saves a matplotlib figure of the original image/segmentation/flows. Does not work for 3D. + This takes a long time for large images. Defaults to False. + + Returns: + None + """ + + if isinstance(masks, list): + for image, mask, flow, file_name in zip(images, masks, flows, file_names): + save_masks(image, mask, flow, file_name, png=png, tif=tif, suffix=suffix, + dir_above=dir_above, save_flows=save_flows, + save_outlines=save_outlines, savedir=savedir, save_txt=save_txt, + in_folders=in_folders, save_mpl=save_mpl) + return + + if masks.ndim > 2 and not tif: + raise ValueError("cannot save 3D outputs as PNG, use tif option instead") + + if masks.max() == 0: + io_logger.warning("no masks found, will not save PNG or outlines") + if not tif: + return + else: + png = False + save_outlines = False + save_flows = False + save_txt = False + + if savedir is None: + if dir_above: + savedir = Path(file_names).parent.parent.absolute( + ) #go up a level to save in its own folder + else: + savedir = Path(file_names).parent.absolute() + + check_dir(savedir) + + basename = os.path.splitext(os.path.basename(file_names))[0] + if in_folders: + maskdir = os.path.join(savedir, "masks") + outlinedir = os.path.join(savedir, "outlines") + txtdir = os.path.join(savedir, "txt_outlines") + flowdir = os.path.join(savedir, "flows") + else: + maskdir = savedir + outlinedir = savedir + txtdir = savedir + flowdir = savedir + + check_dir(maskdir) + + exts = [] + if masks.ndim > 2: + png = False + tif = True + if png: + if masks.max() < 2**16: + masks = masks.astype(np.uint16) + exts.append(".png") + else: + png = False + tif = True + io_logger.warning( + "found more than 65535 masks in each image, cannot save PNG, saving as TIF" + ) + if tif: + exts.append(".tif") + + # save masks + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + for ext in exts: + imsave(os.path.join(maskdir, basename + suffix + ext), masks) + + if save_mpl and png and 
MATPLOTLIB and not min(images.shape) > 3: + # Make and save original/segmentation/flows image + + img = images.copy() + if img.ndim < 3: + img = img[:, :, np.newaxis] + elif img.shape[0] < 8: + np.transpose(img, (1, 2, 0)) + + fig = plt.figure(figsize=(12, 3)) + plot.show_segmentation(fig, img, masks, flows[0]) + fig.savefig(os.path.join(savedir, basename + "_cp_output" + suffix + ".png"), + dpi=300) + plt.close(fig) + + # ImageJ txt outline files + if masks.ndim < 3 and save_txt: + check_dir(txtdir) + outlines = utils.outlines_list(masks) + outlines_to_text(os.path.join(txtdir, basename), outlines) + + # RGB outline images + if masks.ndim < 3 and save_outlines: + check_dir(outlinedir) + outlines = utils.masks_to_outlines(masks) + outX, outY = np.nonzero(outlines) + img0 = transforms.normalize99(images) + if img0.shape[0] < 4: + img0 = np.transpose(img0, (1, 2, 0)) + if img0.shape[-1] < 3 or img0.ndim < 3: + img0 = plot.image_to_rgb(img0, channels=channels) + else: + if img0.max() <= 50.0: + img0 = np.uint8(np.clip(img0 * 255, 0, 1)) + imgout = img0.copy() + imgout[outX, outY] = np.array([255, 0, 0]) #pure red + imsave(os.path.join(outlinedir, basename + "_outlines" + suffix + ".png"), + imgout) + + # save RGB flow picture + if masks.ndim < 3 and save_flows: + check_dir(flowdir) + imsave(os.path.join(flowdir, basename + "_flows" + suffix + ".tif"), + (flows[0] * (2**16 - 1)).astype(np.uint16)) + #save full flow data + imsave(os.path.join(flowdir, basename + '_dP' + suffix + '.tif'), flows[1]) diff --git a/models/seg_post_model/cellpose/metrics.py b/models/seg_post_model/cellpose/metrics.py new file mode 100644 index 0000000000000000000000000000000000000000..6e28005959ed74bc2dc5b4a726787b9cfc2387b5 --- /dev/null +++ b/models/seg_post_model/cellpose/metrics.py @@ -0,0 +1,205 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu. +""" +import numpy as np +from . 
import utils
from scipy.optimize import linear_sum_assignment
from scipy.ndimage import convolve
from scipy.sparse import csr_matrix


def mask_ious(masks_true, masks_pred):
    """Match each true mask to a predicted mask; return per-true-mask IoU and match index."""
    iou = _intersection_over_union(masks_true, masks_pred)[1:, 1:]
    n_min = min(iou.shape)
    # favor pairs above 0.5 IoU, break ties by raw IoU
    cost = -(iou >= 0.5).astype(float) - iou / (2 * n_min)
    ti, pi = linear_sum_assignment(cost)
    n_true = masks_true.max()
    iout = np.zeros(n_true)
    preds = np.zeros(n_true, "int")
    iout[ti] = iou[ti, pi]
    preds[ti] = pi + 1
    return iout, preds


def boundary_scores(masks_true, masks_pred, scales):
    """Boundary precision / recall / F-score at several scales.

    Args:
        masks_true (list): List of true masks.
        masks_pred (list): List of predicted masks.
        scales (list): List of scales.

    Returns:
        tuple: (precision, recall, fscore) arrays of shape [len(scales), len(masks_true)].
    """
    diams = [utils.diameters(lbl)[0] for lbl in masks_true]
    n_img = len(masks_true)
    shape = (len(scales), n_img)
    precision = np.zeros(shape)
    recall = np.zeros(shape)
    fscore = np.zeros(shape)
    for si, scale in enumerate(scales):
        for i in range(n_img):
            diam = max(1, scale * diams[i])
            sz = int(np.ceil(diam))
            rs, ys, xs = utils.circleMask([sz, sz])
            filt = (rs <= diam).astype(np.float32)
            # dilate both outline images by a diameter-scaled disk before comparing
            otrue = convolve(utils.masks_to_outlines(masks_true[i]), filt)
            opred = convolve(utils.masks_to_outlines(masks_pred[i]), filt)
            tp = np.logical_and(otrue == 1, opred == 1).sum()
            fp = np.logical_and(otrue == 0, opred == 1).sum()
            fn = np.logical_and(otrue == 1, opred == 0).sum()
            precision[si, i] = tp / (tp + fp)
            recall[si, i] = tp / (tp + fn)
        fscore[si] = 2 * precision[si] * recall[si] / (precision[si] + recall[si])
    return precision, recall, fscore
def aggregated_jaccard_index(masks_true, masks_pred):
    """AJI = intersection of all matched masks / union of all masks

    Args:
        masks_true (list of np.ndarrays (int) or np.ndarray (int)):
            where 0=NO masks; 1,2... are mask labels
        masks_pred (list of np.ndarrays (int) or np.ndarray (int)):
            np.ndarray (int) where 0=NO masks; 1,2... are mask labels

    Returns:
        aji (float): aggregated jaccard index for each set of masks
    """
    aji = np.zeros(len(masks_true))
    for n in range(len(masks_true)):
        iout, preds = mask_ious(masks_true[n], masks_pred[n])
        inds = np.arange(0, masks_true[n].max(), 1, int)
        # BUGFIX: the original called `_label_overlap`, which is not defined
        # anywhere in this module (NameError at runtime). Build the label-pair
        # pixel-count histogram inline with the same sparse-matrix trick used
        # by _intersection_over_union below.
        overlap = csr_matrix(
            (np.ones((masks_true[n].size,), "int"),
             (masks_true[n].flatten(), masks_pred[n].flatten())),
            shape=(masks_true[n].max() + 1, masks_pred[n].max() + 1)).toarray()
        union = np.logical_or(masks_true[n] > 0, masks_pred[n] > 0).sum()
        # intersection areas of each matched (true, pred) pair; +1 skips background
        overlap = overlap[inds[preds > 0] + 1, preds[preds > 0].astype(int)]
        aji[n] = overlap.sum() / union
    return aji
are mask labels + + Returns: + ap (array [len(masks_true) x len(threshold)]): + average precision at thresholds + tp (array [len(masks_true) x len(threshold)]): + number of true positives at thresholds + fp (array [len(masks_true) x len(threshold)]): + number of false positives at thresholds + fn (array [len(masks_true) x len(threshold)]): + number of false negatives at thresholds + """ + not_list = False + if not isinstance(masks_true, list): + masks_true = [masks_true] + masks_pred = [masks_pred] + not_list = True + if not isinstance(threshold, list) and not isinstance(threshold, np.ndarray): + threshold = [threshold] + + if len(masks_true) != len(masks_pred): + raise ValueError( + "metrics.average_precision requires len(masks_true)==len(masks_pred)") + + ap = np.zeros((len(masks_true), len(threshold)), np.float32) + tp = np.zeros((len(masks_true), len(threshold)), np.float32) + fp = np.zeros((len(masks_true), len(threshold)), np.float32) + fn = np.zeros((len(masks_true), len(threshold)), np.float32) + n_true = np.array([len(np.unique(mt)) - 1 for mt in masks_true]) + n_pred = np.array([len(np.unique(mp)) - 1 for mp in masks_pred]) + + for n in range(len(masks_true)): + #_,mt = np.reshape(np.unique(masks_true[n], return_index=True), masks_pred[n].shape) + if n_pred[n] > 0: + iou = _intersection_over_union(masks_true[n], masks_pred[n])[1:, 1:] + for k, th in enumerate(threshold): + tp[n, k] = _true_positive(iou, th) + fp[n] = n_pred[n] - tp[n] + fn[n] = n_true[n] - tp[n] + ap[n] = tp[n] / (tp[n] + fp[n] + fn[n]) + + if not_list: + ap, tp, fp, fn = ap[0], tp[0], fp[0], fn[0] + return ap, tp, fp, fn + + +def _intersection_over_union(masks_true, masks_pred): + """Calculate the intersection over union of all mask pairs. + + Parameters: + masks_true (np.ndarray, int): Ground truth masks, where 0=NO masks; 1,2... are mask labels. + masks_pred (np.ndarray, int): Predicted masks, where 0=NO masks; 1,2... are mask labels. 
+ + Returns: + iou (np.ndarray, float): Matrix of IOU pairs of size [x.max()+1, y.max()+1]. + + How it works: + The overlap matrix is a lookup table of the area of intersection + between each set of labels (true and predicted). The true labels + are taken to be along axis 0, and the predicted labels are taken + to be along axis 1. The sum of the overlaps along axis 0 is thus + an array giving the total overlap of the true labels with each of + the predicted labels, and likewise the sum over axis 1 is the + total overlap of the predicted labels with each of the true labels. + Because the label 0 (background) is included, this sum is guaranteed + to reconstruct the total area of each label. Adding this row and + column vectors gives a 2D array with the areas of every label pair + added together. This is equivalent to the union of the label areas + except for the duplicated overlap area, so the overlap matrix is + subtracted to find the union matrix. + """ + if masks_true.size != masks_pred.size: + raise ValueError(f"masks_true.size {masks_true.shape} != masks_pred.size {masks_pred.shape}") + overlap = csr_matrix((np.ones((masks_true.size,), "int"), + (masks_true.flatten(), masks_pred.flatten())), + shape=(masks_true.max()+1, masks_pred.max()+1)) + overlap = overlap.toarray() + n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) + n_pixels_true = np.sum(overlap, axis=1, keepdims=True) + iou = overlap / (n_pixels_pred + n_pixels_true - overlap) + iou[np.isnan(iou)] = 0.0 + return iou + + +def _true_positive(iou, th): + """Calculate the true positive at threshold th. + + Args: + iou (float, np.ndarray): Array of IOU pairs. + th (float): Threshold on IOU for positive label. + + Returns: + tp (float): Number of true positives at threshold. + + How it works: + (1) Find minimum number of masks. + (2) Define cost matrix; for a given threshold, each element is negative + the higher the IoU is (perfect IoU is 1, worst is 0). 
The second term + gets more negative with higher IoU, but less negative with greater + n_min (but that's a constant...). + (3) Solve the linear sum assignment problem. The costs array defines the cost + of matching a true label with a predicted label, so the problem is to + find the set of pairings that minimizes this cost. The scipy.optimize + function gives the ordered lists of corresponding true and predicted labels. + (4) Extract the IoUs from these pairings and then threshold to get a boolean array + whose sum is the number of true positives that is returned. + """ + n_min = min(iou.shape[0], iou.shape[1]) + costs = -(iou >= th).astype(float) - iou / (2 * n_min) + true_ind, pred_ind = linear_sum_assignment(costs) + match_ok = iou[true_ind, pred_ind] >= th + tp = match_ok.sum() + return tp diff --git a/models/seg_post_model/cellpose/models.py b/models/seg_post_model/cellpose/models.py new file mode 100644 index 0000000000000000000000000000000000000000..f54a52b9958e63da11c88c2f82c1ab766e2ae024 --- /dev/null +++ b/models/seg_post_model/cellpose/models.py @@ -0,0 +1,524 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer, Michael Rariden and Marius Pachitariu. +""" + +import os, time +from pathlib import Path +import numpy as np +from tqdm import trange +import torch +from scipy.ndimage import gaussian_filter +import gc +import cv2 + +import logging + +models_logger = logging.getLogger(__name__) + +from . 
import transforms, dynamics, utils, plot
from .vit_sam import Transformer
from .core import assign_device, run_net, run_3D

_CPSAM_MODEL_URL = "https://huggingface.co/mouseland/cellpose-sam/resolve/main/cpsam"
_MODEL_DIR_ENV = os.environ.get("CELLPOSE_LOCAL_MODELS_PATH")
# _MODEL_DIR_DEFAULT = Path.home().joinpath(".cellpose", "models")
# NOTE(review): hard-coded machine-specific path; breaks on any other host
# unless CELLPOSE_LOCAL_MODELS_PATH is set — consider restoring the
# Path.home() default above.
_MODEL_DIR_DEFAULT = Path("/media/data1/huix/seg/cellpose_models")
MODEL_DIR = Path(_MODEL_DIR_ENV) if _MODEL_DIR_ENV else _MODEL_DIR_DEFAULT

MODEL_NAMES = ["cpsam"]

MODEL_LIST_PATH = os.fspath(MODEL_DIR.joinpath("gui_models.txt"))

# Default normalization settings passed to transforms.normalize_img.
normalize_default = {
    "lowhigh": None,
    "percentile": None,
    "normalize": True,
    "norm3D": True,
    "sharpen_radius": 0,
    "smooth_radius": 0,
    "tile_norm_blocksize": 0,
    "tile_norm_smooth3D": 1,
    "invert": False
}


# def model_path(model_type, model_index=0):
#     return cache_CPSAM_model_path()


# def cache_CPSAM_model_path():
#     MODEL_DIR.mkdir(parents=True, exist_ok=True)
#     cached_file = os.fspath(MODEL_DIR.joinpath('cpsam'))
#     if not os.path.exists(cached_file):
#         models_logger.info('Downloading: "{}" to {}\n'.format(_CPSAM_MODEL_URL, cached_file))
#         utils.download_url_to_file(_CPSAM_MODEL_URL, cached_file, progress=True)
#     return cached_file


def get_user_models():
    """Return the list of user-registered model names read from MODEL_LIST_PATH.

    Returns an empty list when the file does not exist.
    """
    model_strings = []
    if os.path.exists(MODEL_LIST_PATH):
        with open(MODEL_LIST_PATH, "r") as textfile:
            lines = [line.rstrip() for line in textfile]
            if len(lines) > 0:
                model_strings.extend(lines)
    return model_strings


class CellposeModel():
    """
    Class representing a Cellpose model.

    Attributes:
        diam_mean (float): Mean "diameter" value for the model.
        builtin (bool): Whether the model is a built-in model or not.
        device (torch device): Device used for model running / training.
        nclasses (int): Number of classes in the model.
        nbase (list): List of base values for the model.
        net (CPnet): Cellpose network.
        pretrained_model (str): Path to pretrained cellpose model.
        pretrained_model_ortho (str): Path or model_name for pretrained cellpose model for ortho views in 3D.
        backbone (str): Type of network ("default" is the standard res-unet, "transformer" for the segformer).

    Methods:
        __init__(self, gpu=False, pretrained_model=False, model_type=None, diam_mean=30., device=None):
            Initialize the CellposeModel.

        eval(self, x, batch_size=8, resample=True, channels=None, channel_axis=None, z_axis=None, normalize=True, invert=False, rescale=None, diameter=None, flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0, min_size=15, niter=None, augment=False, tile_overlap=0.1, bsize=224, interp=True, compute_masks=True, progress=None):
            Segment list of images x, or 4D array - Z x C x Y x X.

    """

    def __init__(self, gpu=False, pretrained_model="", model_type=None,
                 diam_mean=None, device=None, nchan=None, use_bfloat16=True, vit_checkpoint=None):
        """
        Initialize the CellposeModel.

        Parameters:
            gpu (bool, optional): Whether or not to save model to GPU, will check if GPU available.
            pretrained_model (str or list of strings, optional): Full path to pretrained cellpose model(s), if None or False, no model loaded.
            model_type (str, optional): Any model that is available in the GUI, use name in GUI e.g. "livecell" (can be user-trained or model zoo).
            diam_mean (float, optional): Mean "diameter", 30. is built-in value for "cyto" model; 17. is built-in value for "nuclei" model; if saved in custom model file (cellpose>=2.0) then it will be loaded automatically and overwrite this value.
            device (torch device, optional): Device used for model running / training (torch.device("cuda") or torch.device("cpu")), overrides gpu input, recommended if you want to use a specific GPU (e.g. torch.device("cuda:1")).
            use_bfloat16 (bool, optional): Use 16bit float precision instead of 32bit for model weights. Default to 16bit (True).
        """
        # if diam_mean is not None:
        #     models_logger.warning(
        #         "diam_mean argument are not used in v4.0.1+. Ignoring this argument..."
        #     )
        # if model_type is not None:
        #     models_logger.warning(
        #         "model_type argument is not used in v4.0.1+. Ignoring this argument..."
        #     )
        # if nchan is not None:
        #     models_logger.warning("nchan argument is deprecated in v4.0.1+. Ignoring this argument")

        ### assign model device
        self.device = assign_device(gpu=gpu)[0] if device is None else device
        # record whether the chosen device is actually an accelerator
        if torch.cuda.is_available():
            device_gpu = self.device.type == "cuda"
        elif torch.backends.mps.is_available():
            device_gpu = self.device.type == "mps"
        else:
            device_gpu = False
        self.gpu = device_gpu

        if pretrained_model is None:
            # raise ValueError("Must specify a pretrained model, training from scratch is not implemented")
            pretrained_model = ""

        ### create neural network
        if pretrained_model and not os.path.exists(pretrained_model):
            # check if pretrained model is in the models directory
            model_strings = get_user_models()
            all_models = MODEL_NAMES.copy()
            all_models.extend(model_strings)
            if pretrained_model in all_models:
                pretrained_model = os.path.join(MODEL_DIR, pretrained_model)
            else:
                pretrained_model = os.path.join(MODEL_DIR, "cpsam")
                # NOTE(review): this warning is logged AFTER pretrained_model
                # is reassigned, so it reports the fallback path rather than
                # the model name that was not found.
                models_logger.warning(
                    f"pretrained model {pretrained_model} not found, using default model"
                )

        self.pretrained_model = pretrained_model
        dtype = torch.bfloat16 if use_bfloat16 else torch.float32
        self.net = Transformer(dtype=dtype, checkpoint=vit_checkpoint).to(self.device)

        if os.path.exists(self.pretrained_model):
            models_logger.info(f">>>> loading model {self.pretrained_model}")
            self.net.load_model(self.pretrained_model, device=self.device)
        # else:
        #     try:
        #         if os.path.split(self.pretrained_model)[-1] != 'cpsam':
        #             raise FileNotFoundError('model file not recognized')
        #         cache_CPSAM_model_path()
        #         self.net.load_model(self.pretrained_model, device=self.device)
        #     except:
        #         print("ViT not initialized")

    def eval(self, x, feat=None, batch_size=8, resample=True, channels=None, channel_axis=None,
             z_axis=None, normalize=True, invert=False, rescale=None, diameter=None,
             flow_threshold=0.4, cellprob_threshold=0.0, do_3D=False, anisotropy=None,
             flow3D_smooth=0, stitch_threshold=0.0,
             min_size=15, max_size_fraction=0.4, niter=None,
             augment=False, tile_overlap=0.1, bsize=256,
             compute_masks=True, progress=None):
        """Segment image(s) x; returns (masks, flows, styles).

        For a list of images (or a squeezed 5D array) the method recurses
        per image and returns lists; otherwise a single image/stack is run.
        flows is [RGB flow image, dP gradients, cellprob].
        NOTE(review): default bsize here is 256 while _run_net defaults to
        224 — confirm which tile size is intended.
        """
        # if rescale is not None:
        #     models_logger.warning("rescaling deprecated in v4.0.1+")
        # if channels is not None:
        #     models_logger.warning("channels deprecated in v4.0.1+. If data contain more than 3 channels, only the first 3 channels will be used")

        if isinstance(x, list) or x.squeeze().ndim == 5:
            # list-of-images path: recurse per image and collect results
            self.timing = []
            masks, styles, flows = [], [], []
            tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
            nimg = len(x)
            iterator = trange(nimg, file=tqdm_out,
                              mininterval=30) if nimg > 1 else range(nimg)
            for i in iterator:
                tic = time.time()
                maski, flowi, stylei = self.eval(
                    x[i],
                    feat=None if feat is None else feat[i],
                    batch_size=batch_size,
                    channel_axis=channel_axis,
                    z_axis=z_axis,
                    normalize=normalize,
                    invert=invert,
                    diameter=diameter[i] if isinstance(diameter, list) or
                    isinstance(diameter, np.ndarray) else diameter,
                    do_3D=do_3D,
                    anisotropy=anisotropy,
                    augment=augment,
                    tile_overlap=tile_overlap,
                    bsize=bsize,
                    resample=resample,
                    flow_threshold=flow_threshold,
                    cellprob_threshold=cellprob_threshold,
                    compute_masks=compute_masks,
                    min_size=min_size,
                    max_size_fraction=max_size_fraction,
                    stitch_threshold=stitch_threshold,
                    flow3D_smooth=flow3D_smooth,
                    progress=progress,
                    niter=niter)
                masks.append(maski)
                flows.append(flowi)
                styles.append(stylei)
                self.timing.append(time.time() - tic)
            return masks, flows, styles

        ############# actual eval code ############
        # reshape image
        x = transforms.convert_image(x, channel_axis=channel_axis,
                                     z_axis=z_axis,
                                     do_3D=(do_3D or stitch_threshold > 0))

        # Add batch dimension if not present
        if x.ndim < 4:
            x = x[np.newaxis, ...]
        if feat is not None:
            if feat.ndim < 4:
                feat = feat[np.newaxis, ...]
        nimg = x.shape[0]

        image_scaling = None
        # remember the original spatial sizes so masks/flows can be resized back
        Ly_0 = x.shape[1]
        Lx_0 = x.shape[2]
        Lz_0 = None
        if do_3D or stitch_threshold > 0:
            Lz_0 = x.shape[0]
        if diameter is not None:
            # rescale so objects approach the model's nominal 30-pixel diameter
            image_scaling = 30. / diameter
            x = transforms.resize_image(x,
                                        Ly=int(x.shape[1] * image_scaling),
                                        Lx=int(x.shape[2] * image_scaling))
            if feat is not None:
                feat = transforms.resize_image(feat,
                                               Ly=int(feat.shape[1] * image_scaling),
                                               Lx=int(feat.shape[2] * image_scaling))

        # normalize image
        normalize_params = normalize_default
        if isinstance(normalize, dict):
            normalize_params = {**normalize_params, **normalize}
        elif not isinstance(normalize, bool):
            raise ValueError("normalize parameter must be a bool or a dict")
        else:
            normalize_params["normalize"] = normalize
            normalize_params["invert"] = invert

        # pre-normalize if 3D stack for stitching or do_3D
        do_normalization = True if normalize_params["normalize"] else False
        if nimg > 1 and do_normalization and (stitch_threshold or do_3D):
            normalize_params["norm3D"] = True if do_3D else normalize_params["norm3D"]
            x = transforms.normalize_img(x, **normalize_params)
            do_normalization = False  # do not normalize again
        else:
            if normalize_params["norm3D"] and nimg > 1 and do_normalization:
                models_logger.warning(
                    "normalize_params['norm3D'] is True but do_3D is False and stitch_threshold=0, so setting to False"
                )
                normalize_params["norm3D"] = False
            if do_normalization:
                x = transforms.normalize_img(x, **normalize_params)

        if feat is not None:
            if feat.shape[-1] > feat.shape[1]:
                # transpose feat to have channels last
                feat = np.moveaxis(feat, 1, -1)

        # ajust the anisotropy when diameter is specified and images are resized:
        if isinstance(anisotropy, (float, int)) and image_scaling:
            anisotropy = image_scaling * anisotropy

        dP, cellprob, styles = self._run_net(
            x,
            feat=feat,
            augment=augment,
            batch_size=batch_size,
            tile_overlap=tile_overlap,
            bsize=bsize,
            do_3D=do_3D,
            anisotropy=anisotropy)

        if do_3D:
            if flow3D_smooth > 0:
                models_logger.info(f"smoothing flows with sigma={flow3D_smooth}")
                # smooth only spatial axes, not the gradient-channel axis
                dP = gaussian_filter(dP, (0, flow3D_smooth, flow3D_smooth, flow3D_smooth))
            torch.cuda.empty_cache()
            gc.collect()

        if resample:
            # upsample flows before computing them:
            dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)
            cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)

        if compute_masks:
            # 200 is the default number of flow-following iterations
            niter0 = 200
            niter = niter0 if niter is None or niter == 0 else niter
            masks = self._compute_masks(x.shape, dP, cellprob, flow_threshold=flow_threshold,
                                        cellprob_threshold=cellprob_threshold, min_size=min_size,
                                        max_size_fraction=max_size_fraction, niter=niter,
                                        stitch_threshold=stitch_threshold, do_3D=do_3D)
        else:
            masks = np.zeros(0)  #pass back zeros if not compute_masks

        masks, dP, cellprob = masks.squeeze(), dP.squeeze(), cellprob.squeeze()

        # undo resizing:
        if image_scaling is not None or anisotropy is not None:

            dP = self._resize_gradients(dP, to_y_size=Ly_0, to_x_size=Lx_0, to_z_size=Lz_0)  # works for 2 or 3D:
            cellprob = self._resize_cellprob(cellprob, to_x_size=Lx_0, to_y_size=Ly_0, to_z_size=Lz_0)

            if do_3D:
                if compute_masks:
                    # Rescale xy then xz:
                    masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
                    masks = masks.transpose(1, 0, 2)
                    masks = transforms.resize_image(masks, Ly=Lz_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)
                    masks = masks.transpose(1, 0, 2)

            else:
                # 2D or 3D stitching case:
                if compute_masks:
                    masks = transforms.resize_image(masks, Ly=Ly_0, Lx=Lx_0, no_channels=True, interpolation=cv2.INTER_NEAREST)

        return masks, [plot.dx_to_circ(dP), dP, cellprob], styles

    def _resize_cellprob(self, prob: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
        """
        Resize cellprob array to specified dimensions for either 2D or 3D.

        Parameters:
            prob (numpy.ndarray): The cellprobs to resize, either in 2D or 3D. Returns the same ndim as provided.
            to_y_size (int): The target size along the Y-axis.
            to_x_size (int): The target size along the X-axis.
            to_z_size (int, optional): The target size along the Z-axis. Required
                for 3D cellprobs.

        Returns:
            numpy.ndarray: The resized cellprobs array with the same number of dimensions
                as the input.

        Raises:
            ValueError: If the input cellprobs array does not have 3 or 4 dimensions.
        """
        prob_shape = prob.shape
        prob = prob.squeeze()
        squeeze_happened = prob.shape != prob_shape
        prob_shape = np.array(prob_shape)

        if prob.ndim == 2:
            # 2D case:
            prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
            if squeeze_happened:
                prob = np.expand_dims(prob, int(np.argwhere(prob_shape == 1)))  # add back empty axis for compatibility
        elif prob.ndim == 3:
            # 3D case: resize in xy, then transpose and resize in zx
            prob = transforms.resize_image(prob, Ly=to_y_size, Lx=to_x_size, no_channels=True)
            prob = prob.transpose(1, 0, 2)
            prob = transforms.resize_image(prob, Ly=to_z_size, Lx=to_x_size, no_channels=True)
            prob = prob.transpose(1, 0, 2)
        else:
            raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 2 or 3, prob shape: {prob.shape}')

        return prob

    def _resize_gradients(self, grads: np.ndarray, to_y_size: int, to_x_size: int, to_z_size: int = None) -> np.ndarray:
        """
        Resize gradient arrays to specified dimensions for either 2D or 3D gradients.

        Parameters:
            grads (np.ndarray): The gradients to resize, either in 2D or 3D. Returns the same ndim as provided.
            to_y_size (int): The target size along the Y-axis.
            to_x_size (int): The target size along the X-axis.
            to_z_size (int, optional): The target size along the Z-axis. Required
                for 3D gradients.

        Returns:
            numpy.ndarray: The resized gradient array with the same number of dimensions
                as the input.

        Raises:
            ValueError: If the input gradient array does not have 3 or 4 dimensions.
        """
        grads_shape = grads.shape
        grads = grads.squeeze()
        squeeze_happened = grads.shape != grads_shape
        grads_shape = np.array(grads_shape)

        if grads.ndim == 3:
            # 2D case, with XY flows in 2 channels:
            grads = np.moveaxis(grads, 0, -1)  # Put gradients last
            grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
            grads = np.moveaxis(grads, -1, 0)  # Put gradients first

            if squeeze_happened:
                grads = np.expand_dims(grads, int(np.argwhere(grads_shape == 1)))  # add back empty axis for compatibility
        elif grads.ndim == 4:
            # dP has gradients that can be treated as channels:
            grads = grads.transpose(1, 2, 3, 0)  # move gradients last:
            grads = transforms.resize_image(grads, Ly=to_y_size, Lx=to_x_size, no_channels=False)
            grads = grads.transpose(1, 0, 2, 3)  # switch axes to resize again
            grads = transforms.resize_image(grads, Ly=to_z_size, Lx=to_x_size, no_channels=False)
            grads = grads.transpose(3, 1, 0, 2)  # undo transposition
        else:
            raise ValueError(f'gradients have incorrect dimension after squeezing. Should be 3 or 4, grads shape: {grads.shape}')

        return grads

    def _run_net(self, x, feat=None,
                 augment=False,
                 batch_size=8, tile_overlap=0.1,
                 bsize=224, anisotropy=1.0, do_3D=False):
        """ run network on image x """
        tic = time.time()
        shape = x.shape
        nimg = shape[0]

        if do_3D:
            Lz, Ly, Lx = shape[:-1]
            if anisotropy is not None and anisotropy != 1.0:
                # stretch z so voxels are approximately isotropic before the net
                models_logger.info(f"resizing 3D image with anisotropy={anisotropy}")
                x = transforms.resize_image(x.transpose(1, 0, 2, 3),
                                            Ly=int(Lz * anisotropy),
                                            Lx=int(Lx)).transpose(1, 0, 2, 3)
            yf, styles = run_3D(self.net, x,
                                batch_size=batch_size, augment=augment,
                                tile_overlap=tile_overlap,
                                bsize=bsize
                                )
            # last output channel is the cell probability, the rest are flows
            cellprob = yf[..., -1]
            dP = yf[..., :-1].transpose((3, 0, 1, 2))
        else:
            yf, styles = run_net(self.net, x, feat=feat, bsize=bsize, augment=augment,
                                 batch_size=batch_size,
                                 tile_overlap=tile_overlap,
                                 )
            cellprob = yf[..., -1]
            dP = yf[..., -3:-1].transpose((3, 0, 1, 2))
            if yf.shape[-1] > 3:
                # extra leading channels are the style vector
                styles = yf[..., :-3]

        styles = styles.squeeze()

        net_time = time.time() - tic
        if nimg > 1:
            models_logger.info("network run in %2.2fs" % (net_time))

        return dP, cellprob, styles

    def _compute_masks(self, shape, dP, cellprob, flow_threshold=0.4, cellprob_threshold=0.0,
                       min_size=15, max_size_fraction=0.4, niter=None,
                       do_3D=False, stitch_threshold=0.0):
        """ compute masks from flows and cell probability """
        changed_device_from = None
        if self.device.type == "mps" and do_3D:
            models_logger.warning("MPS does not support 3D post-processing, switching to CPU")
            self.device = torch.device("cpu")
            changed_device_from = "mps"
        Lz, Ly, Lx = shape[:3]
        tic = time.time()
        if do_3D:
            masks = dynamics.resize_and_compute_masks(
                dP, cellprob, niter=niter, cellprob_threshold=cellprob_threshold,
                flow_threshold=flow_threshold, do_3D=do_3D,
                min_size=min_size, max_size_fraction=max_size_fraction,
                resize=shape[:3] if (np.array(dP.shape[-3:]) != np.array(shape[:3])).sum()
                else None,
                device=self.device)
        else:
            nimg = shape[0]
            Ly0, Lx0 = cellprob[0].shape
            resize = None if Ly0 == Ly and Lx0 == Lx else [Ly, Lx]
            tqdm_out = utils.TqdmToLogger(models_logger, level=logging.INFO)
            iterator = trange(nimg, file=tqdm_out,
                              mininterval=30) if nimg > 1 else range(nimg)
            for i in iterator:
                # turn off min_size for 3D stitching
                min_size0 = min_size if stitch_threshold == 0 or nimg == 1 else -1
                outputs = dynamics.resize_and_compute_masks(
                    dP[:, i], cellprob[i],
                    niter=niter, cellprob_threshold=cellprob_threshold,
                    flow_threshold=flow_threshold, resize=resize,
                    min_size=min_size0, max_size_fraction=max_size_fraction,
                    device=self.device)
                if i == 0 and nimg > 1:
                    # allocate output stack lazily, once the dtype is known
                    masks = np.zeros((nimg, shape[1], shape[2]), outputs.dtype)
                if nimg > 1:
                    masks[i] = outputs
                else:
                    masks = outputs

            if stitch_threshold > 0 and nimg > 1:
                models_logger.info(
                    f"stitching {nimg} planes using stitch_threshold={stitch_threshold:0.3f} to make 3D masks"
                )
                masks = utils.stitch3D(masks, stitch_threshold=stitch_threshold)
                masks = utils.fill_holes_and_remove_small_masks(
                    masks, min_size=min_size)
            elif nimg > 1:
                models_logger.warning(
                    "3D stack used, but stitch_threshold=0 and do_3D=False, so masks are made per plane only"
                )

        flow_time = time.time() - tic
        if shape[0] > 1:
            models_logger.info("masks created in %2.2fs" % (flow_time))

        if changed_device_from is not None:
            models_logger.info("switching back to device %s" % self.device)
            self.device = torch.device(changed_device_from)
        return masks
+""" +import os +import numpy as np +import cv2 +from scipy.ndimage import gaussian_filter +from . import utils, io, transforms + +try: + import matplotlib + MATPLOTLIB_ENABLED = True +except: + MATPLOTLIB_ENABLED = False + +try: + from skimage import color + from skimage.segmentation import find_boundaries + SKIMAGE_ENABLED = True +except: + SKIMAGE_ENABLED = False + + +# modified to use sinebow color +def dx_to_circ(dP): + """Converts the optic flow representation to a circular color representation. + + Args: + dP (ndarray): Flow field components [dy, dx]. + + Returns: + ndarray: The circular color representation of the optic flow. + + """ + mag = 255 * np.clip(transforms.normalize99(np.sqrt(np.sum(dP**2, axis=0))), 0, 1.) + angles = np.arctan2(dP[1], dP[0]) + np.pi + a = 2 + mag /= a + rgb = np.zeros((*dP.shape[1:], 3), "uint8") + rgb[..., 0] = np.clip(mag * (np.cos(angles) + 1), 0, 255).astype("uint8") + rgb[..., 1] = np.clip(mag * (np.cos(angles + 2 * np.pi / 3) + 1), 0, 255).astype("uint8") + rgb[..., 2] = np.clip(mag * (np.cos(angles + 4 * np.pi / 3) + 1), 0, 255).astype("uint8") + + return rgb + + +def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None): + """Plot segmentation results (like on website). + + Can save each panel of figure with file_name option. Use channels option if + img input is not an RGB image with 3 channels. + + Args: + fig (matplotlib.pyplot.figure): Figure in which to make plot. + img (ndarray): 2D or 3D array. Image input into cellpose. + maski (int, ndarray): For image k, masks[k] output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels. + flowi (int, ndarray): For image k, flows[k][0] output from Cellpose.eval (RGB of flows). + channels (list of int, optional): Channels used to run Cellpose, no need to use if image is RGB. Defaults to [0, 0]. + file_name (str, optional): File name of image. If file_name is not None, figure panels are saved. Defaults to None. 
def show_segmentation(fig, img, maski, flowi, channels=[0, 0], file_name=None):
    """Plot segmentation results (like on website).

    Can save each panel of figure with file_name option. Use channels option if
    img input is not an RGB image with 3 channels.

    Args:
        fig (matplotlib.pyplot.figure): Figure in which to make plot.
        img (ndarray): 2D or 3D array. Image input into cellpose.
        maski (int, ndarray): For image k, masks[k] output from Cellpose.eval, where 0=NO masks; 1,2,...=mask labels.
        flowi (int, ndarray): For image k, flows[k][0] output from Cellpose.eval (RGB of flows).
        channels (list of int, optional): Channels used to run Cellpose, no need to use if image is RGB. Defaults to [0, 0].
        file_name (str, optional): File name of image. If file_name is not None, figure panels are saved. Defaults to None.
        seg_norm (bool, optional): Improve cell visibility under labels. Defaults to False.
    """
    if not MATPLOTLIB_ENABLED:
        raise ImportError(
            "matplotlib not installed, install with 'pip install matplotlib'")
    # panel 1: original image
    ax = fig.add_subplot(1, 4, 1)
    img0 = img.copy()

    # channels-first input -> channels-last for display
    if img0.shape[0] < 4:
        img0 = np.transpose(img0, (1, 2, 0))
    if img0.shape[-1] < 3 or img0.ndim < 3:
        img0 = image_to_rgb(img0, channels=channels)
    else:
        # assumes float images in [0, 1] when max <= 50 — rescale for display
        if img0.max() <= 50.0:
            img0 = np.uint8(np.clip(img0, 0, 1) * 255)
    ax.imshow(img0)
    ax.set_title("original image")
    ax.axis("off")

    outlines = utils.masks_to_outlines(maski)

    overlay = mask_overlay(img0, maski)

    # panel 2: outlines drawn in red on top of the image
    ax = fig.add_subplot(1, 4, 2)
    outX, outY = np.nonzero(outlines)
    imgout = img0.copy()
    imgout[outX, outY] = np.array([255, 0, 0])  # pure red

    ax.imshow(imgout)
    ax.set_title("predicted outlines")
    ax.axis("off")

    # panel 3: semi-transparent colored mask overlay
    ax = fig.add_subplot(1, 4, 3)
    ax.imshow(overlay)
    ax.set_title("predicted masks")
    ax.axis("off")

    # panel 4: RGB rendering of the predicted flow field
    ax = fig.add_subplot(1, 4, 4)
    ax.imshow(flowi)
    ax.set_title("predicted cell pose")
    ax.axis("off")

    # optionally save each panel next to the input file
    if file_name is not None:
        save_path = os.path.splitext(file_name)[0]
        io.imsave(save_path + "_overlay.jpg", overlay)
        io.imsave(save_path + "_outlines.jpg", imgout)
        io.imsave(save_path + "_flows.jpg", flowi)
+ """ + if colors is not None: + if colors.max() > 1: + colors = np.float32(colors) + colors /= 255 + colors = utils.rgb_to_hsv(colors) + + HSV = np.zeros((masks.shape[0], masks.shape[1], 3), np.float32) + HSV[:, :, 2] = 1.0 + for n in range(int(masks.max())): + ipix = (masks == n + 1).nonzero() + if colors is None: + HSV[ipix[0], ipix[1], 0] = np.random.rand() + else: + HSV[ipix[0], ipix[1], 0] = colors[n, 0] + HSV[ipix[0], ipix[1], 1] = np.random.rand() * 0.5 + 0.5 + HSV[ipix[0], ipix[1], 2] = np.random.rand() * 0.5 + 0.5 + RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8) + return RGB + + +def mask_overlay(img, masks, colors=None): + """Overlay masks on image (set image to grayscale). + + Args: + img (int or float, 2D or 3D array): Image of size [Ly x Lx (x nchan)]. + masks (int, 2D array): Masks where 0=NO masks; 1,2,...=mask labels. + colors (int, 2D array, optional): Size [nmasks x 3], each entry is a color in 0-255 range. + + Returns: + RGB (uint8, 3D array): Array of masks overlaid on grayscale image. + """ + if colors is not None: + if colors.max() > 1: + colors = np.float32(colors) + colors /= 255 + colors = utils.rgb_to_hsv(colors) + if img.ndim > 2: + img = img.astype(np.float32).mean(axis=-1) + else: + img = img.astype(np.float32) + + HSV = np.zeros((img.shape[0], img.shape[1], 3), np.float32) + HSV[:, :, 2] = np.clip((img / 255. if img.max() > 1 else img) * 1.5, 0, 1) + hues = np.linspace(0, 1, masks.max() + 1)[np.random.permutation(masks.max())] + for n in range(int(masks.max())): + ipix = (masks == n + 1).nonzero() + if colors is None: + HSV[ipix[0], ipix[1], 0] = hues[n] + else: + HSV[ipix[0], ipix[1], 0] = colors[n, 0] + HSV[ipix[0], ipix[1], 1] = 1.0 + RGB = (utils.hsv_to_rgb(HSV) * 255).astype(np.uint8) + return RGB + + +def image_to_rgb(img0, channels=[0, 0]): + """Converts image from 2 x Ly x Lx or Ly x Lx x 2 to RGB Ly x Lx x 3. + + Args: + img0 (ndarray): Input image of shape 2 x Ly x Lx or Ly x Lx x 2. 
+ + Returns: + ndarray: RGB image of shape Ly x Lx x 3. + + """ + img = img0.copy() + img = img.astype(np.float32) + if img.ndim < 3: + img = img[:, :, np.newaxis] + if img.shape[0] < 5: + img = np.transpose(img, (1, 2, 0)) + if channels[0] == 0: + img = img.mean(axis=-1)[:, :, np.newaxis] + for i in range(img.shape[-1]): + if np.ptp(img[:, :, i]) > 0: + img[:, :, i] = np.clip(transforms.normalize99(img[:, :, i]), 0, 1) + img[:, :, i] = np.clip(img[:, :, i], 0, 1) + img *= 255 + img = np.uint8(img) + RGB = np.zeros((img.shape[0], img.shape[1], 3), np.uint8) + if img.shape[-1] == 1: + RGB = np.tile(img, (1, 1, 3)) + else: + RGB[:, :, channels[0] - 1] = img[:, :, 0] + if channels[1] > 0: + RGB[:, :, channels[1] - 1] = img[:, :, 1] + return RGB + + +def interesting_patch(mask, bsize=130): + """ + Get patch of size bsize x bsize with most masks. + + Args: + mask (ndarray): Input mask. + bsize (int): Size of the patch. + + Returns: + tuple: Patch coordinates (y, x). + + """ + Ly, Lx = mask.shape + m = np.float32(mask > 0) + m = gaussian_filter(m, bsize / 2) + y, x = np.unravel_index(np.argmax(m), m.shape) + ycent = max(bsize // 2, min(y, Ly - bsize // 2)) + xcent = max(bsize // 2, min(x, Lx - bsize // 2)) + patch = [ + np.arange(ycent - bsize // 2, ycent + bsize // 2, 1, int), + np.arange(xcent - bsize // 2, xcent + bsize // 2, 1, int) + ] + return patch + + +def disk(med, r, Ly, Lx): + """Returns the pixels of a disk with a given radius and center. + + Args: + med (tuple): The center coordinates of the disk. + r (float): The radius of the disk. + Ly (int): The height of the image. + Lx (int): The width of the image. + + Returns: + tuple: A tuple containing the y and x coordinates of the pixels within the disk. 
+ + """ + yy, xx = np.meshgrid(np.arange(0, Ly, 1, int), np.arange(0, Lx, 1, int), + indexing="ij") + inds = ((yy - med[0])**2 + (xx - med[1])**2)**0.5 <= r + y = yy[inds].flatten() + x = xx[inds].flatten() + return y, x + + +def outline_view(img0, maski, color=[1, 0, 0], mode="inner"): + """ + Generates a red outline overlay onto the image. + + Args: + img0 (numpy.ndarray): The input image. + maski (numpy.ndarray): The mask representing the region of interest. + color (list, optional): The color of the outline overlay. Defaults to [1, 0, 0] (red). + mode (str, optional): The mode for generating the outline. Defaults to "inner". + + Returns: + numpy.ndarray: The image with the red outline overlay. + + """ + if img0.ndim == 2: + img0 = np.stack([img0] * 3, axis=-1) + elif img0.ndim != 3: + raise ValueError("img0 not right size (must have ndim 2 or 3)") + + if SKIMAGE_ENABLED: + outlines = find_boundaries(maski, mode=mode) + else: + outlines = utils.masks_to_outlines(maski, mode=mode) + outY, outX = np.nonzero(outlines) + imgout = img0.copy() + imgout[outY, outX] = np.array(color) + + return imgout diff --git a/models/seg_post_model/cellpose/transforms.py b/models/seg_post_model/cellpose/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..3f3e6405fece41ce580dc9fdb290c384e0cfaec4 --- /dev/null +++ b/models/seg_post_model/cellpose/transforms.py @@ -0,0 +1,1261 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu. +""" +import logging + +import cv2 +import numpy as np +import torch +from scipy.ndimage import gaussian_filter1d +from torch.fft import fft2, fftshift, ifft2 + +transforms_logger = logging.getLogger(__name__) + + +def _taper_mask(ly=224, lx=224, sig=7.5): + """ + Generate a taper mask. + + Args: + ly (int): The height of the mask. Default is 224. + lx (int): The width of the mask. Default is 224. + sig (float): The sigma value for the tapering function. 
            Default is 7.5.

    Returns:
        numpy.ndarray: The taper mask.

    """
    bsize = max(224, max(ly, lx))
    xm = np.arange(bsize)
    xm = np.abs(xm - xm.mean())
    # sigmoid taper: ~1 in the center, rolling off ~20 px from the border
    mask = 1 / (1 + np.exp((xm - (bsize / 2 - 20)) / sig))
    mask = mask * mask[:, np.newaxis]
    # crop the central ly x lx window out of the bsize x bsize mask
    mask = mask[bsize // 2 - ly // 2:bsize // 2 + ly // 2 + ly % 2,
                bsize // 2 - lx // 2:bsize // 2 + lx // 2 + lx % 2]
    return mask


def unaugment_tiles(y):
    """Reverse test-time augmentations for averaging (includes flipping of flowsY and flowsX).

    Args:
        y (float32): Array of shape (ntiles_y, ntiles_x, chan, Ly, Lx) where chan = (flowsY, flowsX, cell prob).

    Returns:
        float32: Array of shape (ntiles_y, ntiles_x, chan, Ly, Lx).

    """
    for j in range(y.shape[0]):
        for i in range(y.shape[1]):
            # undo the per-parity flips applied in make_tiles(augment=True);
            # flipping a spatial axis also negates the matching flow component
            if j % 2 == 0 and i % 2 == 1:
                y[j, i] = y[j, i, :, ::-1, :]
                y[j, i, 0] *= -1
            elif j % 2 == 1 and i % 2 == 0:
                y[j, i] = y[j, i, :, :, ::-1]
                y[j, i, 1] *= -1
            elif j % 2 == 1 and i % 2 == 1:
                y[j, i] = y[j, i, :, ::-1, ::-1]
                y[j, i, 0] *= -1
                y[j, i, 1] *= -1
    return y


def average_tiles(y, ysub, xsub, Ly, Lx):
    """
    Average the results of the network over tiles.

    Args:
        y (float): Output of cellpose network for each tile. Shape: [ntiles x nclasses x bsize x bsize]
        ysub (list): List of arrays with start and end of tiles in Y of length ntiles
        xsub (list): List of arrays with start and end of tiles in X of length ntiles
        Ly (int): Size of pre-tiled image in Y (may be larger than original image if image size is less than bsize)
        Lx (int): Size of pre-tiled image in X (may be larger than original image if image size is less than bsize)

    Returns:
        yf (float32): Network output averaged over tiles. Shape: [nclasses x Ly x Lx]
    """
    Navg = np.zeros((Ly, Lx))
    yf = np.zeros((y.shape[1], Ly, Lx), np.float32)
    # taper edges of tiles
    mask = _taper_mask(ly=y.shape[-2], lx=y.shape[-1])
    for j in range(len(ysub)):
        yf[:, ysub[j][0]:ysub[j][1], xsub[j][0]:xsub[j][1]] += y[j] * mask
        Navg[ysub[j][0]:ysub[j][1], xsub[j][0]:xsub[j][1]] += mask
    # normalize by the accumulated taper weights
    yf /= Navg
    return yf


def make_tiles(imgi, bsize=224, augment=False, tile_overlap=0.1):
    """Make tiles of image to run at test-time.

    Args:
        imgi (np.ndarray): Array of shape (nchan, Ly, Lx) representing the input image.
        bsize (int, optional): Size of tiles. Defaults to 224.
        augment (bool, optional): Whether to flip tiles and set tile_overlap=2. Defaults to False.
        tile_overlap (float, optional): Fraction of overlap of tiles. Defaults to 0.1.

    Returns:
        A tuple containing (IMG, ysub, xsub, Ly, Lx):
            IMG (np.ndarray): Array of shape (ntiles, nchan, bsize, bsize) representing the tiles.
            ysub (list): List of arrays with start and end of tiles in Y of length ntiles.
            xsub (list): List of arrays with start and end of tiles in X of length ntiles.
            Ly (int): Height of the input image.
            Lx (int): Width of the input image.
    """
    nchan, Ly, Lx = imgi.shape
    if augment:
        bsize = np.int32(bsize)
        # pad if image smaller than bsize
        if Ly < bsize:
            imgi = np.concatenate((imgi, np.zeros((nchan, bsize - Ly, Lx))), axis=1)
            Ly = bsize
        if Lx < bsize:
            imgi = np.concatenate((imgi, np.zeros((nchan, Ly, bsize - Lx))), axis=2)
        Ly, Lx = imgi.shape[-2:]

        # tiles overlap by half of tile size
        ny = max(2, int(np.ceil(2. * Ly / bsize)))
        nx = max(2, int(np.ceil(2. * Lx / bsize)))
        ystart = np.linspace(0, Ly - bsize, ny).astype(int)
        xstart = np.linspace(0, Lx - bsize, nx).astype(int)

        ysub = []
        xsub = []

        # flip tiles so that overlapping segments are processed in rotation
        IMG = np.zeros((len(ystart), len(xstart), nchan, bsize, bsize), np.float32)
        for j in range(len(ystart)):
            for i in range(len(xstart)):
                ysub.append([ystart[j], ystart[j] + bsize])
                xsub.append([xstart[i], xstart[i] + bsize])
                IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]]
                # flip tiles to allow for augmentation of overlapping segments
                if j % 2 == 0 and i % 2 == 1:
                    IMG[j, i] = IMG[j, i, :, ::-1, :]
                elif j % 2 == 1 and i % 2 == 0:
                    IMG[j, i] = IMG[j, i, :, :, ::-1]
                elif j % 2 == 1 and i % 2 == 1:
                    IMG[j, i] = IMG[j, i, :, ::-1, ::-1]
    else:
        tile_overlap = min(0.5, max(0.05, tile_overlap))
        bsizeY, bsizeX = min(bsize, Ly), min(bsize, Lx)
        bsizeY = np.int32(bsizeY)
        bsizeX = np.int32(bsizeX)
        # tiles overlap by 10% tile size
        ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))
        nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))
        ystart = np.linspace(0, Ly - bsizeY, ny).astype(int)
        xstart = np.linspace(0, Lx - bsizeX, nx).astype(int)

        ysub = []
        xsub = []
        IMG = np.zeros((len(ystart), len(xstart), nchan, bsizeY, bsizeX), np.float32)
        for j in range(len(ystart)):
            for i in range(len(xstart)):
                ysub.append([ystart[j], ystart[j] + bsizeY])
                xsub.append([xstart[i], xstart[i] + bsizeX])
                IMG[j, i] = imgi[:, ysub[-1][0]:ysub[-1][1], xsub[-1][0]:xsub[-1][1]]

    return IMG, ysub, xsub, Ly, Lx


def normalize99(Y, lower=1, upper=99, copy=True, downsample=False):
    """
    Normalize the image so that 0.0 corresponds to the 1st percentile and 1.0 corresponds to the 99th percentile.

    Args:
        Y (ndarray): The input image (for downsample, use [Ly x Lx] or [Lz x Ly x Lx]).
        lower (int, optional): The lower percentile. Defaults to 1.
+ upper (int, optional): The upper percentile. Defaults to 99. + copy (bool, optional): Whether to create a copy of the input image. Defaults to True. + downsample (bool, optional): Whether to downsample image to compute percentiles. Defaults to False. + + Returns: + ndarray: The normalized image. + """ + X = Y.copy() if copy else Y + X = X.astype("float32") if X.dtype!="float64" and X.dtype!="float32" else X + if downsample and X.size > 224**3: + nskip = [max(1, X.shape[i] // 224) for i in range(X.ndim)] + nskip[0] = max(1, X.shape[0] // 50) if X.ndim == 3 else nskip[0] + slc = tuple([slice(0, X.shape[i], nskip[i]) for i in range(X.ndim)]) + x01 = np.percentile(X[slc], lower) + x99 = np.percentile(X[slc], upper) + else: + x01 = np.percentile(X, lower) + x99 = np.percentile(X, upper) + if x99 - x01 > 1e-3: + X -= x01 + X /= (x99 - x01) + else: + X[:] = 0 + return X + + +def normalize99_tile(img, blocksize=100, lower=1., upper=99., tile_overlap=0.1, + norm3D=False, smooth3D=1, is3D=False): + """Compute normalization like normalize99 function but in tiles. + + Args: + img (numpy.ndarray): Array of shape (Lz x) Ly x Lx (x nchan) containing the image. + blocksize (float, optional): Size of tiles. Defaults to 100. + lower (float, optional): Lower percentile for normalization. Defaults to 1.0. + upper (float, optional): Upper percentile for normalization. Defaults to 99.0. + tile_overlap (float, optional): Fraction of overlap of tiles. Defaults to 0.1. + norm3D (bool, optional): Use same tiled normalization for each z-plane. Defaults to False. + smooth3D (int, optional): Smoothing factor for 3D normalization. Defaults to 1. + is3D (bool, optional): Set to True if image is a 3D stack. Defaults to False. + + Returns: + numpy.ndarray: Normalized image array of shape (Lz x) Ly x Lx (x nchan). 
+ """ + is1c = True if img.ndim == 2 or (is3D and img.ndim == 3) else False + is3D = True if img.ndim > 3 or (is3D and img.ndim == 3) else False + img = img[..., np.newaxis] if is1c else img + img = img[np.newaxis, ...] if img.ndim == 3 else img + Lz, Ly, Lx, nchan = img.shape + + tile_overlap = min(0.5, max(0.05, tile_overlap)) + blocksizeY, blocksizeX = min(blocksize, Ly), min(blocksize, Lx) + blocksizeY = np.int32(blocksizeY) + blocksizeX = np.int32(blocksizeX) + # tiles overlap by 10% tile size + ny = 1 if Ly <= blocksize else int(np.ceil( + (1. + 2 * tile_overlap) * Ly / blocksize)) + nx = 1 if Lx <= blocksize else int(np.ceil( + (1. + 2 * tile_overlap) * Lx / blocksize)) + ystart = np.linspace(0, Ly - blocksizeY, ny).astype(int) + xstart = np.linspace(0, Lx - blocksizeX, nx).astype(int) + ysub = [] + xsub = [] + for j in range(len(ystart)): + for i in range(len(xstart)): + ysub.append([ystart[j], ystart[j] + blocksizeY]) + xsub.append([xstart[i], xstart[i] + blocksizeX]) + + x01_tiles_z = [] + x99_tiles_z = [] + for z in range(Lz): + IMG = np.zeros((len(ystart), len(xstart), blocksizeY, blocksizeX, nchan), + "float32") + k = 0 + for j in range(len(ystart)): + for i in range(len(xstart)): + IMG[j, i] = img[z, ysub[k][0]:ysub[k][1], xsub[k][0]:xsub[k][1], :] + k += 1 + x01_tiles = np.percentile(IMG, lower, axis=(-3, -2)) + x99_tiles = np.percentile(IMG, upper, axis=(-3, -2)) + + # fill areas with small differences with neighboring squares + to_fill = np.zeros(x01_tiles.shape[:2], "bool") + for c in range(nchan): + to_fill = x99_tiles[:, :, c] - x01_tiles[:, :, c] < +1e-3 + if to_fill.sum() > 0 and to_fill.sum() < x99_tiles[:, :, c].size: + fill_vals = np.nonzero(to_fill) + fill_neigh = np.nonzero(~to_fill) + nearest_neigh = ( + (fill_vals[0] - fill_neigh[0][:, np.newaxis])**2 + + (fill_vals[1] - fill_neigh[1][:, np.newaxis])**2).argmin(axis=0) + x01_tiles[fill_vals[0], fill_vals[1], + c] = x01_tiles[fill_neigh[0][nearest_neigh], + fill_neigh[1][nearest_neigh], 
c] + x99_tiles[fill_vals[0], fill_vals[1], + c] = x99_tiles[fill_neigh[0][nearest_neigh], + fill_neigh[1][nearest_neigh], c] + elif to_fill.sum() > 0 and to_fill.sum() == x99_tiles[:, :, c].size: + x01_tiles[:, :, c] = 0 + x99_tiles[:, :, c] = 1 + x01_tiles_z.append(x01_tiles) + x99_tiles_z.append(x99_tiles) + + x01_tiles_z = np.array(x01_tiles_z) + x99_tiles_z = np.array(x99_tiles_z) + # do not smooth over z-axis if not normalizing separately per plane + for a in range(2): + x01_tiles_z = gaussian_filter1d(x01_tiles_z, 1, axis=a) + x99_tiles_z = gaussian_filter1d(x99_tiles_z, 1, axis=a) + if norm3D: + smooth3D = 1 if smooth3D == 0 else smooth3D + x01_tiles_z = gaussian_filter1d(x01_tiles_z, smooth3D, axis=a) + x99_tiles_z = gaussian_filter1d(x99_tiles_z, smooth3D, axis=a) + + if not norm3D and Lz > 1: + x01 = np.zeros((len(x01_tiles_z), Ly, Lx, nchan), "float32") + x99 = np.zeros((len(x01_tiles_z), Ly, Lx, nchan), "float32") + for z in range(Lz): + x01_rsz = cv2.resize(x01_tiles_z[z], (Lx, Ly), + interpolation=cv2.INTER_LINEAR) + x01[z] = x01_rsz[..., np.newaxis] if nchan == 1 else x01_rsz + x99_rsz = cv2.resize(x99_tiles_z[z], (Lx, Ly), + interpolation=cv2.INTER_LINEAR) + x99[z] = x99_rsz[..., np.newaxis] if nchan == 1 else x01_rsz + if (x99 - x01).min() < 1e-3: + raise ZeroDivisionError( + "cannot use norm3D=False with tile_norm, sample is too sparse; set norm3D=True or tile_norm=0" + ) + else: + x01 = cv2.resize(x01_tiles_z.mean(axis=0), (Lx, Ly), + interpolation=cv2.INTER_LINEAR) + x99 = cv2.resize(x99_tiles_z.mean(axis=0), (Lx, Ly), + interpolation=cv2.INTER_LINEAR) + if x01.ndim < 3: + x01 = x01[..., np.newaxis] + x99 = x99[..., np.newaxis] + + if is1c: + img, x01, x99 = img.squeeze(), x01.squeeze(), x99.squeeze() + elif not is3D: + img, x01, x99 = img[0], x01[0], x99[0] + + # normalize + img -= x01 + img /= (x99 - x01) + + return img + + +def gaussian_kernel(sigma, Ly, Lx, device=torch.device("cpu")): + """ + Generates a 2D Gaussian kernel. 
+ + Args: + sigma (float): Standard deviation of the Gaussian distribution. + Ly (int): Number of pixels in the y-axis. + Lx (int): Number of pixels in the x-axis. + device (torch.device, optional): Device to store the kernel tensor. Defaults to torch.device("cpu"). + + Returns: + torch.Tensor: 2D Gaussian kernel tensor. + + """ + y = torch.linspace(-Ly / 2, Ly / 2 + 1, Ly, device=device) + x = torch.linspace(-Ly / 2, Ly / 2 + 1, Lx, device=device) + y, x = torch.meshgrid(y, x, indexing="ij") + kernel = torch.exp(-(y**2 + x**2) / (2 * sigma**2)) + kernel /= kernel.sum() + return kernel + + +def smooth_sharpen_img(img, smooth_radius=6, sharpen_radius=12, + device=torch.device("cpu"), is3D=False): + """Sharpen blurry images with surround subtraction and/or smooth noisy images. + + Args: + img (float32): Array that's (Lz x) Ly x Lx (x nchan). + smooth_radius (float, optional): Size of gaussian smoothing filter, recommended to be 1/10-1/4 of cell diameter + (if also sharpening, should be 2-3x smaller than sharpen_radius). Defaults to 6. + sharpen_radius (float, optional): Size of gaussian surround filter, recommended to be 1/8-1/2 of cell diameter + (if also smoothing, should be 2-3x larger than smooth_radius). Defaults to 12. + device (torch.device, optional): Device on which to perform sharpening. + Will be faster on GPU but need to ensure GPU has RAM for image. Defaults to torch.device("cpu"). + is3D (bool, optional): If image is 3D stack (only necessary to set if img.ndim==3). Defaults to False. + + Returns: + img_sharpen (float32): Array that's (Lz x) Ly x Lx (x nchan). 
+ """ + img_sharpen = torch.from_numpy(img.astype("float32")).to(device) + shape = img_sharpen.shape + + is1c = True if img_sharpen.ndim == 2 or (is3D and img_sharpen.ndim == 3) else False + is3D = True if img_sharpen.ndim > 3 or (is3D and img_sharpen.ndim == 3) else False + img_sharpen = img_sharpen.unsqueeze(-1) if is1c else img_sharpen + img_sharpen = img_sharpen.unsqueeze(0) if img_sharpen.ndim == 3 else img_sharpen + Lz, Ly, Lx, nchan = img_sharpen.shape + + if smooth_radius > 0: + kernel = gaussian_kernel(smooth_radius, Ly, Lx, device=device) + if sharpen_radius > 0: + kernel += -1 * gaussian_kernel(sharpen_radius, Ly, Lx, device=device) + elif sharpen_radius > 0: + kernel = -1 * gaussian_kernel(sharpen_radius, Ly, Lx, device=device) + kernel[Ly // 2, Lx // 2] = 1 + + fhp = fft2(kernel) + for z in range(Lz): + for c in range(nchan): + img_filt = torch.real(ifft2( + fft2(img_sharpen[z, :, :, c]) * torch.conj(fhp))) + img_filt = fftshift(img_filt) + img_sharpen[z, :, :, c] = img_filt + + img_sharpen = img_sharpen.reshape(shape) + return img_sharpen.cpu().numpy() + + +def move_axis(img, m_axis=-1, first=True): + """ move axis m_axis to first or last position """ + if m_axis == -1: + m_axis = img.ndim - 1 + m_axis = min(img.ndim - 1, m_axis) + axes = np.arange(0, img.ndim) + if first: + axes[1:m_axis + 1] = axes[:m_axis] + axes[0] = m_axis + else: + axes[m_axis:-1] = axes[m_axis + 1:] + axes[-1] = m_axis + img = img.transpose(tuple(axes)) + return img + + +def move_min_dim(img, force=False): + """Move the minimum dimension last as channels if it is less than 10 or force is True. + + Args: + img (ndarray): The input image. + force (bool, optional): If True, the minimum dimension will always be moved. + Defaults to False. + + Returns: + ndarray: The image with the minimum dimension moved to the last axis as channels. 
+ """ + if len(img.shape) > 2: + min_dim = min(img.shape) + if min_dim < 10 or force: + if img.shape[-1] == min_dim: + channel_axis = -1 + else: + channel_axis = (img.shape).index(min_dim) + img = move_axis(img, m_axis=channel_axis, first=False) + return img + + +def update_axis(m_axis, to_squeeze, ndim): + """ + Squeeze the axis value based on the given parameters. + + Args: + m_axis (int): The current axis value. + to_squeeze (numpy.ndarray): An array of indices to squeeze. + ndim (int): The number of dimensions. + + Returns: + m_axis (int or None): The updated axis value. + """ + if m_axis == -1: + m_axis = ndim - 1 + if (to_squeeze == m_axis).sum() == 1: + m_axis = None + else: + inds = np.ones(ndim, bool) + inds[to_squeeze] = False + m_axis = np.nonzero(np.arange(0, ndim)[inds] == m_axis)[0] + if len(m_axis) > 0: + m_axis = m_axis[0] + else: + m_axis = None + return m_axis + + +def _convert_image_3d(x, channel_axis=None, z_axis=None): + """ + Convert a 3D or 4D image array to have dimensions ordered as (Z, X, Y, C). + + Arrays of ndim=3 are assumed to be grayscale and must be specified with z_axis. + Arrays of ndim=4 must have both `channel_axis` and `z_axis` specified. + + Args: + x (numpy.ndarray): Input image array. Must be either 3D (assumed to be grayscale 3D) or 4D. + channel_axis (int): The axis index corresponding to the channel dimension in the input array. \ + Must be specified for 4D images. + z_axis (int): The axis index corresponding to the depth (Z) dimension in the input array. \ + Must be specified for both 3D and 4D images. + + Returns: + numpy.ndarray: A 4D image array with dimensions ordered as (Z, X, Y, C), where C is the channel + dimension. If the input has fewer than 3 channels, the output will be padded with zeros to \ + have 3 channels. If the input has more than 3 channels, only the first 3 channels will be retained. + + Raises: + ValueError: If `z_axis` is not specified for 3D images. 
            If either `channel_axis` or `z_axis` \
            is not specified for 4D images. If the input image does not have 3 or 4 dimensions.

    Notes:
        - For 3D images (ndim=3), the function assumes the input is grayscale and adds a singleton channel dimension.
        - The function reorders the dimensions of the input array to ensure the output has the desired (Z, X, Y, C) order.
        - If the number of channels is not equal to 3, the function either truncates or pads the \
          channels to ensure the output has exactly 3 channels.
    """

    if x.ndim < 3:
        raise ValueError(f"Input image must have at least 3 dimensions, input shape: {x.shape}, ndim={x.ndim}")

    if z_axis is not None and z_axis < 0:
        z_axis += x.ndim

    # if image is ndim==3, assume it is greyscale 3D and use provided z_axis
    if x.ndim == 3 and z_axis is not None:
        # add in channel axis
        x = x[..., np.newaxis]
        channel_axis = 3
    elif x.ndim == 3 and z_axis is None:
        raise ValueError("z_axis must be specified when segmenting 3D images of ndim=3")


    if channel_axis is None or z_axis is None:
        raise ValueError("For 4D images, both `channel_axis` and `z_axis` must be explicitly specified. Please provide values for both parameters.")
    if channel_axis is not None and channel_axis < 0:
        channel_axis += x.ndim
    if channel_axis is None or channel_axis >= x.ndim:
        raise IndexError(f"channel_axis {channel_axis} is out of bounds for input array with {x.ndim} dimensions")
    assert x.ndim == 4, f"input image must have ndim == 4, ndim={x.ndim}"

    x_dim_shapes = list(x.shape)
    num_z_layers = x_dim_shapes[z_axis]
    num_channels = x_dim_shapes[channel_axis]
    x_xy_axes = [i for i in range(x.ndim)]

    # need to remove the z and channels from the shapes:
    # delete the one with the bigger index first
    if z_axis > channel_axis:
        del x_dim_shapes[z_axis]
        del x_dim_shapes[channel_axis]

        del x_xy_axes[z_axis]
        del x_xy_axes[channel_axis]

    else:
        del x_dim_shapes[channel_axis]
        del x_dim_shapes[z_axis]

        del x_xy_axes[channel_axis]
        del x_xy_axes[z_axis]

    # the two remaining axes are the spatial ones, in original order
    x = x.transpose((z_axis, x_xy_axes[0], x_xy_axes[1], channel_axis))

    # Handle cases with not 3 channels:
    if num_channels != 3:
        x_chans_to_copy = min(3, num_channels)

        if num_channels > 3:
            transforms_logger.warning("more than 3 channels provided, only segmenting on first 3 channels")
            x = x[..., :x_chans_to_copy]
        else:
            # less than 3 channels: pad up to
            pad_width = [(0, 0), (0, 0), (0, 0), (0, 3 - x_chans_to_copy)]
            x = np.pad(x, pad_width, mode='constant', constant_values=0)

    return x


def convert_image(x, channel_axis=None, z_axis=None, do_3D=False):
    """Converts the image to have the z-axis first, channels last. Image will be converted to 3 channels if it is not already.
    If more than 3 channels are provided, only the first 3 channels will be used.

    Accepts:
        - 2D images with no channel dimension: `z_axis` and `channel_axis` must be `None`
        - 2D images with channel dimension: `channel_axis` will be guessed between first or last axis, can also specify `channel_axis`. `z_axis` must be `None`
        - 3D images with or without channels:

    Args:
        x (numpy.ndarray or torch.Tensor): The input image.
        channel_axis (int or None): The axis of the channels in the input image. If None, the axis is determined automatically.
        z_axis (int or None): The axis of the z-dimension in the input image. If None, the axis is determined automatically.
        do_3D (bool): Whether to process the image in 3D mode. Defaults to False.

    Returns:
        numpy.ndarray: The converted image.

    Raises:
        ValueError: If the input image is 2D and do_3D is True.
        ValueError: If the input image is 4D and do_3D is False.
    """

    # check if image is a torch array instead of numpy array, convert to numpy
    ndim = x.ndim
    if torch.is_tensor(x):
        transforms_logger.warning("torch array used as input, converting to numpy")
        x = x.cpu().numpy()

    # should be 2D
    if z_axis is not None and not do_3D:
        raise ValueError("2D image provided, but z_axis is not None. Set z_axis=None to process 2D images of ndim=2 or 3.")

    # make sure that channel_axis and z_axis are specified if 3D
    if ndim == 4 and not do_3D:
        raise ValueError("3D input image provided, but do_3D is False. Set do_3D=True to process 3D images. ndims=4")

    # make sure that channel_axis and z_axis are specified if 3D
    if do_3D:
        return _convert_image_3d(x, channel_axis=channel_axis, z_axis=z_axis)

    ######################## 2D reshaping ########################
    # if user specifies channel axis, return early
    if channel_axis is not None:
        if ndim == 2:
            raise ValueError("2D image provided, but channel_axis is not None. Set channel_axis=None to process 2D images of ndim=2.")

        # Put channel axis last:
        # Find the indices of the dims that need to be put in dim 0 and 1
        n_channels = x.shape[channel_axis]
        x_shape_dims = list(x.shape)
        del x_shape_dims[channel_axis]
        dimension_indicies = [i for i in range(x.ndim)]
        del dimension_indicies[channel_axis]

        x = x.transpose((dimension_indicies[0], dimension_indicies[1], channel_axis))

        if n_channels != 3:
            x_chans_to_copy = min(3, n_channels)

            if n_channels > 3:
                transforms_logger.warning("more than 3 channels provided, only segmenting on first 3 channels")
                x = x[..., :x_chans_to_copy]
            else:
                # zero-pad up to 3 channels
                x_out = np.zeros((x_shape_dims[0], x_shape_dims[1], 3), dtype=x.dtype)
                x_out[..., :x_chans_to_copy] = x[...]
                x = x_out
                del x_out

        return x

    # do image padding and channel conversion
    if ndim == 2:
        # grayscale image, make 3 channels
        x_out = np.zeros((x.shape[0], x.shape[1], 3), dtype=x.dtype)
        x_out[..., 0] = x
        x = x_out
        del x_out
    elif ndim == 3:
        # assume 2d with channels
        # find dim with smaller size between first and last dims
        move_channel_axis = x.shape[0] < x.shape[2]
        if move_channel_axis:
            x = x.transpose((1, 2, 0))

        # zero padding up to 3 channels:
        num_channels = x.shape[-1]
        if num_channels > 3:
            transforms_logger.warning("Found more than 3 channels, only using first 3")
            num_channels = 3
        x_out = np.zeros((x.shape[0], x.shape[1], 3), dtype=x.dtype)
        x_out[..., :num_channels] = x[..., :num_channels]
        x = x_out
        del x_out
    else:
        # something is wrong: yell
        expected_shapes = "2D (H, W), 3D (H, W, C), or 4D (Z, H, W, C)"
        transforms_logger.critical(f"ERROR: Unexpected image shape: {str(x.shape)}. Expected shapes: {expected_shapes}")
        raise ValueError(f"ERROR: Unexpected image shape: {str(x.shape)}. Expected shapes: {expected_shapes}")

    return x


def normalize_img(img, normalize=True, norm3D=True, invert=False, lowhigh=None,
                  percentile=(1., 99.), sharpen_radius=0, smooth_radius=0,
                  tile_norm_blocksize=0, tile_norm_smooth3D=1, axis=-1):
    """Normalize each channel of the image with optional inversion, smoothing, and sharpening.

    Args:
        img (ndarray): The input image. It should have at least 3 dimensions.
            If it is 4-dimensional, it assumes the first non-channel axis is the Z dimension.
        normalize (bool, optional): Whether to perform normalization. Defaults to True.
        norm3D (bool, optional): Whether to normalize in 3D. If True, the entire 3D stack will
            be normalized per channel. If False, normalization is applied per Z-slice. Defaults to True.
        invert (bool, optional): Whether to invert the image. Useful if cells are dark instead of bright.
            Defaults to False.
        lowhigh (tuple or ndarray, optional): The lower and upper bounds for normalization.
            Can be a tuple of two values (applied to all channels) or an array of shape (nchan, 2)
            for per-channel normalization. Incompatible with smoothing and sharpening.
            Defaults to None.
        percentile (tuple, optional): The lower and upper percentiles for normalization. If provided, it should be
            a tuple of two values. Each value should be between 0 and 100. Defaults to (1.0, 99.0).
        sharpen_radius (int, optional): The radius for sharpening the image. Defaults to 0.
        smooth_radius (int, optional): The radius for smoothing the image. Defaults to 0.
        tile_norm_blocksize (int, optional): The block size for tile-based normalization. Defaults to 0.
        tile_norm_smooth3D (int, optional): The smoothness factor for tile-based normalization in 3D. Defaults to 1.
        axis (int, optional): The channel axis to loop over for normalization. Defaults to -1.

    Returns:
        ndarray: The normalized image of the same size.

    Raises:
        ValueError: If the image has less than 3 dimensions.
+ ValueError: If the provided lowhigh or percentile values are invalid. + ValueError: If the image is inverted without normalization. + + """ + if img.ndim < 3: + error_message = "Image needs to have at least 3 dimensions" + transforms_logger.critical(error_message) + raise ValueError(error_message) + + img_norm = img if img.dtype=="float32" else img.astype(np.float32) + if axis != -1 and axis != img_norm.ndim - 1: + img_norm = np.moveaxis(img_norm, axis, -1) # Move channel axis to last + + nchan = img_norm.shape[-1] + + # Validate and handle lowhigh bounds + if lowhigh is not None: + lowhigh = np.array(lowhigh) + if lowhigh.shape == (2,): + lowhigh = np.tile(lowhigh, (nchan, 1)) # Expand to per-channel bounds + elif lowhigh.shape != (nchan, 2): + error_message = "`lowhigh` must have shape (2,) or (nchan, 2)" + transforms_logger.critical(error_message) + raise ValueError(error_message) + + # Validate percentile + if percentile is None: + percentile = (1.0, 99.0) + elif not (0 <= percentile[0] < percentile[1] <= 100): + error_message = "Invalid percentile range, should be between 0 and 100" + transforms_logger.critical(error_message) + raise ValueError(error_message) + + # Apply normalization based on lowhigh or percentile + cgood = np.zeros(nchan, "bool") + if lowhigh is not None: + for c in range(nchan): + lower = lowhigh[c, 0] + upper = lowhigh[c, 1] + img_norm[..., c] -= lower + img_norm[..., c] /= (upper - lower) + cgood[c] = True + else: + # Apply sharpening and smoothing if specified + if sharpen_radius > 0 or smooth_radius > 0: + img_norm = smooth_sharpen_img( + img_norm, sharpen_radius=sharpen_radius, smooth_radius=smooth_radius + ) + + # Apply tile-based normalization or standard normalization + if tile_norm_blocksize > 0: + img_norm = normalize99_tile( + img_norm, + blocksize=tile_norm_blocksize, + lower=percentile[0], + upper=percentile[1], + smooth3D=tile_norm_smooth3D, + norm3D=norm3D, + ) + cgood[:] = True + elif normalize: + if img_norm.ndim == 3 or 
def resize_safe(img, Ly, Lx, interpolation=cv2.INTER_LINEAR):
    """Resize an image, working around OpenCV's lack of uint32 support.

    cv2.resize does not accept uint32 input (see
    https://github.com/MouseLand/cellpose/issues/937), so uint32 images are
    round-tripped through float32. Not lossless for values > 2**24, hence
    "safe" only in the sense of not raising.

    Args:
        img (ndarray): Image of size [Ly x Lx].
        Ly (int): Desired height of the resized image.
        Lx (int): Desired width of the resized image.
        interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR.

    Returns:
        ndarray: Resized image of size [Ly x Lx], same dtype as input.
    """
    # uint32 must be cast before resize and restored afterwards
    cast = img.dtype == np.uint32
    if cast:
        img = img.astype(np.float32)

    # note cv2 size argument is (width, height)
    img = cv2.resize(img, (Lx, Ly), interpolation=interpolation)

    if cast:
        # round before casting back so label values are preserved
        img = img.round().astype(np.uint32)

    return img


def resize_image(img0, Ly=None, Lx=None, rsz=None, interpolation=cv2.INTER_LINEAR,
                 no_channels=False):
    """Resize image for computing flows / unresize for computing dynamics.

    Args:
        img0 (ndarray): Image of size [Y x X x nchan] or [Lz x Y x X x nchan] or [Lz x Y x X].
        Ly (int, optional): Desired height of the resized image. Defaults to None.
        Lx (int, optional): Desired width of the resized image. Defaults to None.
        rsz (float, optional): Resize coefficient(s) for the image. Only used if Ly is None. Defaults to None.
        interpolation (int, optional): OpenCV interpolation method. Defaults to cv2.INTER_LINEAR.
        no_channels (bool, optional): Flag indicating whether to treat the third dimension as a channel.
            Defaults to False.

    Returns:
        ndarray: Resized image of size [Ly x Lx x nchan] or [Lz x Ly x Lx x nchan].

    Raises:
        ValueError: If both Ly and rsz are None, or if the resize ratio yields a zero-sized axis.
    """
    if Ly is None and rsz is None:
        error_message = "must give size to resize to or factor to use for resizing"
        transforms_logger.critical(error_message)
        raise ValueError(error_message)

    if Ly is None:
        # determine Ly and Lx using rsz
        if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
            rsz = [rsz, rsz]
        if no_channels:
            Ly = int(img0.shape[-2] * rsz[-2])
            Lx = int(img0.shape[-1] * rsz[-1])
        else:
            Ly = int(img0.shape[-3] * rsz[-2])
            Lx = int(img0.shape[-2] * rsz[-1])

    # no_channels useful for z-stacks, so the third dimension is not treated as a channel
    # but if this is called for grayscale images, they first become [Ly,Lx,2] so ndim=3
    if (img0.ndim > 2 and no_channels) or (img0.ndim == 4 and not no_channels):
        if Ly == 0 or Lx == 0:
            raise ValueError(
                "anisotropy too high / low -- not enough pixels to resize to ratio")
        # resize each z-slice independently; output array allocated on first slice
        for i, img in enumerate(img0):
            imgi = resize_safe(img, Ly, Lx, interpolation=interpolation)
            if i == 0:
                if no_channels:
                    imgs = np.zeros((img0.shape[0], Ly, Lx), imgi.dtype)
                else:
                    imgs = np.zeros((img0.shape[0], Ly, Lx, img0.shape[-1]), imgi.dtype)
            # cv2 drops a trailing singleton channel axis -- restore it if needed
            imgs[i] = imgi if imgi.ndim > 2 or no_channels else imgi[..., np.newaxis]
    else:
        imgs = resize_safe(img0, Ly, Lx, interpolation=interpolation)
    return imgs


def get_pad_yx(Ly, Lx, div=16, extra=1, min_size=None):
    """Compute per-side Y/X pad widths so each axis is a multiple of `div`.

    `extra` adds `extra * div // 2` pixels on every side; `min_size` (tuple,
    [..., Ly, Lx]) forces padding up to a minimum size instead when the axis
    is smaller than it.

    Returns:
        tuple of int: (ypad1, ypad2, xpad1, xpad2) pad widths.
    """
    if min_size is None or Ly >= min_size[-2]:
        Lpad = int(div * np.ceil(Ly / div) - Ly)
    else:
        Lpad = min_size[-2] - Ly
    ypad1 = extra * div // 2 + Lpad // 2
    ypad2 = extra * div // 2 + Lpad - Lpad // 2
    if min_size is None or Lx >= min_size[-1]:
        Lpad = int(div * np.ceil(Lx / div) - Lx)
    else:
        Lpad = min_size[-1] - Lx
    xpad1 = extra * div // 2 + Lpad // 2
    xpad2 = extra * div // 2 + Lpad - Lpad // 2

    return ypad1, ypad2, xpad1, xpad2


def pad_image_ND(img0, div=16, extra=1, min_size=None, zpad=False):
    """Pad image for test-time so that its dimensions are a multiple of 16 (2D or 3D).

    Args:
        img0 (ndarray): Image of size [nchan (x Lz) x Ly x Lx].
        div (int, optional): Divisor for padding. Defaults to 16.
        extra (int, optional): Extra padding. Defaults to 1.
        min_size (tuple, optional): Minimum size of the image. Defaults to None.
        zpad (bool, optional): Also pad the Z axis (requires img0.ndim > 3). Defaults to False.

    Returns:
        A tuple containing (I, ysub, xsub) or, with zpad, (I, ysub, xsub, zsub);
        I is the padded image and the *sub arrays are the pixel ranges in I
        corresponding to img0.

    Raises:
        NameError: NOTE(review) -- if zpad=True with a 3D (ndim<=3) input, zpad1 is
            never assigned; callers must only set zpad for 4D inputs.
    """
    Ly, Lx = img0.shape[-2:]
    ypad1, ypad2, xpad1, xpad2 = get_pad_yx(Ly, Lx, div=div, extra=extra, min_size=min_size)

    if img0.ndim > 3:
        if zpad:
            Lpad = int(div * np.ceil(img0.shape[-3] / div) - img0.shape[-3])
            zpad1 = extra * div // 2 + Lpad // 2
            zpad2 = extra * div // 2 + Lpad - Lpad // 2
        else:
            zpad1, zpad2 = 0, 0
        pads = np.array([[0, 0], [zpad1, zpad2], [ypad1, ypad2], [xpad1, xpad2]])
    else:
        pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])

    I = np.pad(img0, pads, mode="constant")

    ysub = np.arange(ypad1, ypad1 + Ly)
    xsub = np.arange(xpad1, xpad1 + Lx)
    if zpad:
        zsub = np.arange(zpad1, zpad1 + img0.shape[-3])
        return I, ysub, xsub, zsub
    else:
        return I, ysub, xsub


def random_rotate_and_resize(X, Y=None, scale_range=1., xy=(224, 224), do_3D=False,
                             zcrop=48, do_flip=True, rotate=True, rescale=None, unet=False,
                             random_per_image=True):
    """Augmentation by random rotation and resizing.

    Args:
        X (list of ND-arrays, float): List of image arrays of size [nchan x Ly x Lx] or [Ly x Lx].
        Y (list of ND-arrays, float, optional): List of image labels of size [nlabels x Ly x Lx] or [Ly x Lx].
            The 1st channel of Y is always nearest-neighbor interpolated (assumed to be masks or 0-1 representation).
            If Y.shape[0]==3 and not unet, then the labels are assumed to be [cell probability, Y flow, X flow].
            If unet, second channel is dist_to_bound. Defaults to None.
        scale_range (float, optional): Range of resizing of images for augmentation.
            Images are resized by (1-scale_range/2) + scale_range * np.random.rand(). Defaults to 1.0.
        xy (tuple, int, optional): Size of transformed images to return. Defaults to (224,224).
        do_3D (bool, optional): Treat inputs as Z-stacks and also crop/flip in Z. Defaults to False.
        zcrop (int, optional): Z size of the returned crop when do_3D. Defaults to 48.
        do_flip (bool, optional): Whether or not to flip images horizontally. Defaults to True.
        rotate (bool, optional): Whether or not to rotate images. Defaults to True.
        rescale (array, float, optional): How much to resize images by before performing augmentations. Defaults to None.
        unet (bool, optional): Whether or not to use unet. Defaults to False.
        random_per_image (bool, optional): Different random rotate and resize per image. Defaults to True.

    Returns:
        A tuple containing (imgi, lbl, scale): imgi (ND-array, float): Transformed images in array [nimg x nchan x xy[0] x xy[1]];
        lbl (ND-array, float): Transformed labels in array [nimg x nchan x xy[0] x xy[1]] (empty list if Y is None);
        scale (array, float): Amount each image was resized by.
    """
    # clamp scale_range into [0, 2]; None is kept to trigger the log-uniform branch below
    scale_range = max(0, min(2, float(scale_range))) if scale_range is not None else scale_range
    nimg = len(X)
    nchan = X[0].shape[0] if X[0].ndim > 2 else 1
    if do_3D and X[0].ndim > 3:
        shape = (zcrop, xy[0], xy[1])
    else:
        shape = (xy[0], xy[1])
    imgi = np.zeros((nimg, nchan, *shape), "float32")

    lbl = []
    if Y is not None:
        nt = Y[0].shape[0] if Y[0].ndim > 2 else 1
        lbl = np.zeros((nimg, nt, *shape), np.float32)

    scale = np.ones(nimg, np.float32)

    for n in range(nimg):

        if random_per_image or n == 0:
            Ly, Lx = X[n].shape[-2:]
            # generate random augmentation parameters
            flip = np.random.rand() > .5
            theta = np.random.rand() * np.pi * 2 if rotate else 0.
            if scale_range is None:
                scale[n] = 2 ** (4 * np.random.rand() - 2)
            else:
                scale[n] = (1 - scale_range / 2) + scale_range * np.random.rand()
            if rescale is not None:
                scale[n] *= 1. / rescale[n]
            # random translation within the slack left after scaling
            dxy = np.maximum(0, np.array([Lx * scale[n] - xy[1],
                                          Ly * scale[n] - xy[0]]))
            dxy = (np.random.rand(2,) - .5) * dxy

            # create affine transform from 3 point correspondences
            cc = np.array([Lx / 2, Ly / 2])
            cc1 = cc - np.array([Lx - xy[1], Ly - xy[0]]) / 2 + dxy
            pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
            pts2 = np.float32([
                cc1,
                cc1 + scale[n] * np.array([np.cos(theta), np.sin(theta)]),
                cc1 + scale[n] *
                np.array([np.cos(np.pi / 2 + theta),
                          np.sin(np.pi / 2 + theta)])
            ])
            M = cv2.getAffineTransform(pts1, pts2)

        img = X[n].copy()
        if Y is not None:
            labels = Y[n].copy()
            if labels.ndim < 3:
                labels = labels[np.newaxis, :, :]

        if do_3D:
            # random Z crop of length lz (scaled so that resize to zcrop matches XY scaling)
            Lz = X[n].shape[-3]
            flip_z = np.random.rand() > .5
            lz = int(np.round(zcrop / scale[n]))
            iz = np.random.randint(0, Lz - lz)
            img = img[:, iz:iz + lz, :, :]
            if Y is not None:
                labels = labels[:, iz:iz + lz, :, :]

        if do_flip:
            if flip:
                img = img[..., ::-1]
                if Y is not None:
                    labels = labels[..., ::-1]
                    if nt > 1 and not unet:
                        # X flow (last channel) changes sign under horizontal flip
                        labels[-1] = -labels[-1]
            if do_3D and flip_z:
                img = img[:, ::-1]
                if Y is not None:
                    labels = labels[:, ::-1]
                    if nt > 1 and not unet:
                        # Z flow changes sign under Z flip
                        labels[-3] = -labels[-3]

        for k in range(nchan):
            if do_3D:
                img0 = np.zeros((lz, xy[0], xy[1]), "float32")
                for z in range(lz):
                    I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
                                       flags=cv2.INTER_LINEAR)
                    img0[z] = I
                if scale[n] != 1.0:
                    # resize Z axis plane-by-plane to zcrop
                    for y in range(imgi.shape[-2]):
                        imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
                                                      interpolation=cv2.INTER_LINEAR)
                else:
                    imgi[n, k] = img0
            else:
                I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
                imgi[n, k] = I

        if Y is not None:
            for k in range(nt):
                # first channels (masks / cellprob) use nearest neighbor, flows use linear
                flag = cv2.INTER_NEAREST if k < nt - 2 else cv2.INTER_LINEAR
                if do_3D:
                    lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
                    for z in range(lz):
                        I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
                                           flags=flag)
                        lbl0[z] = I
                    if scale[n] != 1.0:
                        for y in range(lbl.shape[-2]):
                            lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
                                                         interpolation=flag)
                    else:
                        lbl[n, k] = lbl0
                else:
                    lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)

            if nt > 1 and not unet:
                # rotate the (Y, X) flow vectors by -theta to match the image rotation
                v1 = lbl[n, -1].copy()
                v2 = lbl[n, -2].copy()
                lbl[n, -2] = (-v1 * np.sin(-theta) + v2 * np.cos(-theta))
                lbl[n, -1] = (v1 * np.cos(-theta) + v2 * np.sin(-theta))

    return imgi, lbl, scale


def random_rotate_and_resize_with_feat(X, Y=None, feat=None, scale_range=1., xy=(224, 224), do_3D=False,
                                       zcrop=48, do_flip=True, rotate=True, rescale=None, unet=False,
                                       random_per_image=True):
    """Augmentation by random rotation and resizing, also transforming feature maps.

    Same as random_rotate_and_resize, but an optional list of per-image feature
    maps `feat` is carried through the identical crop/flip/affine pipeline
    (always with linear interpolation).

    Args:
        X (list of ND-arrays, float): List of image arrays of size [nchan x Ly x Lx] or [Ly x Lx].
        Y (list of ND-arrays, float, optional): List of image labels of size [nlabels x Ly x Lx] or [Ly x Lx].
            The 1st channel of Y is always nearest-neighbor interpolated (assumed to be masks or 0-1 representation).
            If Y.shape[0]==3 and not unet, then the labels are assumed to be [cell probability, Y flow, X flow].
            If unet, second channel is dist_to_bound. Defaults to None.
        feat (list of ND-arrays, float, optional): List of feature arrays of size [nf x Ly x Lx] or [Ly x Lx],
            spatially aligned with X. Defaults to None.
        scale_range (float, optional): Range of resizing of images for augmentation.
            Images are resized by (1-scale_range/2) + scale_range * np.random.rand(). Defaults to 1.0.
        xy (tuple, int, optional): Size of transformed images to return. Defaults to (224,224).
        do_3D (bool, optional): Treat inputs as Z-stacks and also crop/flip in Z. Defaults to False.
        zcrop (int, optional): Z size of the returned crop when do_3D. Defaults to 48.
        do_flip (bool, optional): Whether or not to flip images horizontally. Defaults to True.
        rotate (bool, optional): Whether or not to rotate images. Defaults to True.
        rescale (array, float, optional): How much to resize images by before performing augmentations. Defaults to None.
        unet (bool, optional): Whether or not to use unet. Defaults to False.
        random_per_image (bool, optional): Different random rotate and resize per image. Defaults to True.

    Returns:
        A tuple (imgi, lbl, feat_out, scale): imgi (ND-array, float): Transformed images [nimg x nchan x ...];
        lbl (ND-array, float or []): Transformed labels; feat_out (ND-array, float or []): Transformed
        features (empty list if feat is None); scale (array, float): Amount each image was resized by.
    """
    scale_range = max(0, min(2, float(scale_range))) if scale_range is not None else scale_range
    nimg = len(X)
    nchan = X[0].shape[0] if X[0].ndim > 2 else 1
    if do_3D and X[0].ndim > 3:
        shape = (zcrop, xy[0], xy[1])
    else:
        shape = (xy[0], xy[1])
    imgi = np.zeros((nimg, nchan, *shape), "float32")

    lbl = []
    if Y is not None:
        nt = Y[0].shape[0] if Y[0].ndim > 2 else 1
        lbl = np.zeros((nimg, nt, *shape), np.float32)

    # BUGFIX: feat_out must exist even when feat is None, otherwise the
    # unconditional `return imgi, lbl, feat_out, scale` raised NameError.
    feat_out = []
    if feat is not None:
        nf = feat[0].shape[0] if feat[0].ndim > 2 else 1
        feat_out = np.zeros((nimg, nf, *shape), "float32")

    scale = np.ones(nimg, np.float32)

    for n in range(nimg):

        if random_per_image or n == 0:
            Ly, Lx = X[n].shape[-2:]
            # generate random augmentation parameters
            flip = np.random.rand() > .5
            theta = np.random.rand() * np.pi * 2 if rotate else 0.
            if scale_range is None:
                scale[n] = 2 ** (4 * np.random.rand() - 2)
            else:
                scale[n] = (1 - scale_range / 2) + scale_range * np.random.rand()
            if rescale is not None:
                scale[n] *= 1. / rescale[n]
            dxy = np.maximum(0, np.array([Lx * scale[n] - xy[1],
                                          Ly * scale[n] - xy[0]]))
            dxy = (np.random.rand(2,) - .5) * dxy

            # create affine transform
            cc = np.array([Lx / 2, Ly / 2])
            cc1 = cc - np.array([Lx - xy[1], Ly - xy[0]]) / 2 + dxy
            pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
            pts2 = np.float32([
                cc1,
                cc1 + scale[n] * np.array([np.cos(theta), np.sin(theta)]),
                cc1 + scale[n] *
                np.array([np.cos(np.pi / 2 + theta),
                          np.sin(np.pi / 2 + theta)])
            ])
            M = cv2.getAffineTransform(pts1, pts2)

        img = X[n].copy()
        if Y is not None:
            labels = Y[n].copy()
            if labels.ndim < 3:
                labels = labels[np.newaxis, :, :]
        if feat is not None:
            feats = feat[n].copy()
            if feats.ndim < 3:
                feats = feats[np.newaxis, :, :]

        if do_3D:
            Lz = X[n].shape[-3]
            flip_z = np.random.rand() > .5
            lz = int(np.round(zcrop / scale[n]))
            iz = np.random.randint(0, Lz - lz)
            img = img[:, iz:iz + lz, :, :]
            if Y is not None:
                labels = labels[:, iz:iz + lz, :, :]
            if feat is not None:
                feats = feats[:, iz:iz + lz, :, :]

        if do_flip:
            if flip:
                img = img[..., ::-1]
                if Y is not None:
                    labels = labels[..., ::-1]
                    if nt > 1 and not unet:
                        labels[-1] = -labels[-1]
                if feat is not None:
                    feats = feats[..., ::-1]
            if do_3D and flip_z:
                img = img[:, ::-1]
                if Y is not None:
                    labels = labels[:, ::-1]
                    if nt > 1 and not unet:
                        labels[-3] = -labels[-3]
                if feat is not None:
                    feats = feats[:, ::-1]

        for k in range(nchan):
            if do_3D:
                img0 = np.zeros((lz, xy[0], xy[1]), "float32")
                for z in range(lz):
                    I = cv2.warpAffine(img[k, z], M, (xy[1], xy[0]),
                                       flags=cv2.INTER_LINEAR)
                    img0[z] = I
                if scale[n] != 1.0:
                    for y in range(imgi.shape[-2]):
                        imgi[n, k, :, y] = cv2.resize(img0[:, y], (xy[1], zcrop),
                                                      interpolation=cv2.INTER_LINEAR)
                else:
                    imgi[n, k] = img0
            else:
                I = cv2.warpAffine(img[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)
                imgi[n, k] = I

        if Y is not None:
            for k in range(nt):
                flag = cv2.INTER_NEAREST if k < nt - 2 else cv2.INTER_LINEAR
                if do_3D:
                    lbl0 = np.zeros((lz, xy[0], xy[1]), "float32")
                    for z in range(lz):
                        I = cv2.warpAffine(labels[k, z], M, (xy[1], xy[0]),
                                           flags=flag)
                        lbl0[z] = I
                    if scale[n] != 1.0:
                        for y in range(lbl.shape[-2]):
                            lbl[n, k, :, y] = cv2.resize(lbl0[:, y], (xy[1], zcrop),
                                                         interpolation=flag)
                    else:
                        lbl[n, k] = lbl0
                else:
                    lbl[n, k] = cv2.warpAffine(labels[k], M, (xy[1], xy[0]), flags=flag)

            if nt > 1 and not unet:
                v1 = lbl[n, -1].copy()
                v2 = lbl[n, -2].copy()
                lbl[n, -2] = (-v1 * np.sin(-theta) + v2 * np.cos(-theta))
                lbl[n, -1] = (v1 * np.cos(-theta) + v2 * np.sin(-theta))

        if feat is not None:
            # features always use linear interpolation (continuous values)
            for k in range(nf):
                if do_3D:
                    feat0 = np.zeros((lz, xy[0], xy[1]), "float32")
                    for z in range(lz):
                        I = cv2.warpAffine(feats[k, z], M, (xy[1], xy[0]),
                                           flags=cv2.INTER_LINEAR)
                        feat0[z] = I
                    if scale[n] != 1.0:
                        for y in range(feat_out.shape[-2]):
                            feat_out[n, k, :, y] = cv2.resize(feat0[:, y], (xy[1], zcrop),
                                                              interpolation=cv2.INTER_LINEAR)
                    else:
                        feat_out[n, k] = feat0
                else:
                    feat_out[n, k] = cv2.warpAffine(feats[k], M, (xy[1], xy[0]), flags=cv2.INTER_LINEAR)

    return imgi, lbl, feat_out, scale
+""" +import logging +import os, tempfile, shutil, io +from tqdm import tqdm, trange +from urllib.request import urlopen +import cv2 +from scipy.ndimage import find_objects, gaussian_filter, generate_binary_structure, label +from scipy.spatial import ConvexHull +import numpy as np +import colorsys +import fastremap +import fill_voids +from multiprocessing import Pool, cpu_count +# try: +# from cellpose import metrics +# except: +# import metrics as metrics +from models.seg_post_model.cellpose import metrics + +try: + from skimage.morphology import remove_small_holes + SKIMAGE_ENABLED = True +except: + SKIMAGE_ENABLED = False + + +class TqdmToLogger(io.StringIO): + """ + Output stream for TQDM which will output to logger module instead of + the StdOut. + """ + logger = None + level = None + buf = "" + + def __init__(self, logger, level=None): + super(TqdmToLogger, self).__init__() + self.logger = logger + self.level = level or logging.INFO + + def write(self, buf): + self.buf = buf.strip("\r\n\t ") + + def flush(self): + self.logger.log(self.level, self.buf) + + +def rgb_to_hsv(arr): + rgb_to_hsv_channels = np.vectorize(colorsys.rgb_to_hsv) + r, g, b = np.rollaxis(arr, axis=-1) + h, s, v = rgb_to_hsv_channels(r, g, b) + hsv = np.stack((h, s, v), axis=-1) + return hsv + + +def hsv_to_rgb(arr): + hsv_to_rgb_channels = np.vectorize(colorsys.hsv_to_rgb) + h, s, v = np.rollaxis(arr, axis=-1) + r, g, b = hsv_to_rgb_channels(h, s, v) + rgb = np.stack((r, g, b), axis=-1) + return rgb + + +def download_url_to_file(url, dst, progress=True): + r"""Download object at the given URL to a local path. + Thanks to torch, slightly modified + Args: + url (string): URL of the object to download + dst (string): Full path where object will be saved, e.g. 
`/tmp/temporary_file` + progress (bool, optional): whether or not to display a progress bar to stderr + Default: True + """ + file_size = None + import ssl + ssl._create_default_https_context = ssl._create_unverified_context + u = urlopen(url) + meta = u.info() + if hasattr(meta, "getheaders"): + content_length = meta.getheaders("Content-Length") + else: + content_length = meta.get_all("Content-Length") + if content_length is not None and len(content_length) > 0: + file_size = int(content_length[0]) + # We deliberately save it in a temp file and move it after + dst = os.path.expanduser(dst) + dst_dir = os.path.dirname(dst) + f = tempfile.NamedTemporaryFile(delete=False, dir=dst_dir) + try: + with tqdm(total=file_size, disable=not progress, unit="B", unit_scale=True, + unit_divisor=1024) as pbar: + while True: + buffer = u.read(8192) + if len(buffer) == 0: + break + f.write(buffer) + pbar.update(len(buffer)) + f.close() + shutil.move(f.name, dst) + finally: + f.close() + if os.path.exists(f.name): + os.remove(f.name) + + +def distance_to_boundary(masks): + """Get the distance to the boundary of mask pixels. + + Args: + masks (int, 2D or 3D array): The masks array. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels. + + Returns: + dist_to_bound (2D or 3D array): The distance to the boundary. Size [Ly x Lx] or [Lz x Ly x Lx]. + + Raises: + ValueError: If the masks array is not 2D or 3D. 
+ + """ + if masks.ndim > 3 or masks.ndim < 2: + raise ValueError("distance_to_boundary takes 2D or 3D array, not %dD array" % + masks.ndim) + dist_to_bound = np.zeros(masks.shape, np.float64) + + if masks.ndim == 3: + for i in range(masks.shape[0]): + dist_to_bound[i] = distance_to_boundary(masks[i]) + return dist_to_bound + else: + slices = find_objects(masks) + for i, si in enumerate(slices): + if si is not None: + sr, sc = si + mask = (masks[sr, sc] == (i + 1)).astype(np.uint8) + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T + ypix, xpix = np.nonzero(mask) + min_dist = ((ypix[:, np.newaxis] - pvr)**2 + + (xpix[:, np.newaxis] - pvc)**2).min(axis=1) + dist_to_bound[ypix + sr.start, xpix + sc.start] = min_dist + return dist_to_bound + + +def masks_to_edges(masks, threshold=1.0): + """Get edges of masks as a 0-1 array. + + Args: + masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels. + threshold (float, optional): Threshold value for distance to boundary. Defaults to 1.0. + + Returns: + edges (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are edge pixels. + """ + dist_to_bound = distance_to_boundary(masks) + edges = (dist_to_bound < threshold) * (masks > 0) + return edges + + +def remove_edge_masks(masks, change_index=True): + """Removes masks with pixels on the edge of the image. + + Args: + masks (int, 2D or 3D array): The masks to be processed. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels. + change_index (bool, optional): If True, after removing masks, changes the indexing so that there are no missing label numbers. Defaults to True. + + Returns: + outlines (2D or 3D array): The processed masks. Size [Ly x Lx] or [Lz x Ly x Lx], where 0 represents no mask and 1, 2, ... represent mask labels. 
+ """ + slices = find_objects(masks.astype(int)) + for i, si in enumerate(slices): + remove = False + if si is not None: + for d, sid in enumerate(si): + if sid.start == 0 or sid.stop == masks.shape[d]: + remove = True + break + if remove: + masks[si][masks[si] == i + 1] = 0 + shape = masks.shape + if change_index: + _, masks = np.unique(masks, return_inverse=True) + masks = np.reshape(masks, shape).astype(np.int32) + + return masks + + +def masks_to_outlines(masks): + """Get outlines of masks as a 0-1 array. + + Args: + masks (int, 2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where 0=NO masks and 1,2,...=mask labels. + + Returns: + outlines (2D or 3D array): Size [Ly x Lx] or [Lz x Ly x Lx], where True pixels are outlines. + """ + if masks.ndim > 3 or masks.ndim < 2: + raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" % + masks.ndim) + outlines = np.zeros(masks.shape, bool) + + if masks.ndim == 3: + for i in range(masks.shape[0]): + outlines[i] = masks_to_outlines(masks[i]) + return outlines + else: + slices = find_objects(masks.astype(int)) + for i, si in enumerate(slices): + if si is not None: + sr, sc = si + mask = (masks[sr, sc] == (i + 1)).astype(np.uint8) + contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, + cv2.CHAIN_APPROX_NONE) + pvc, pvr = np.concatenate(contours[-2], axis=0).squeeze().T + vr, vc = pvr + sr.start, pvc + sc.start + outlines[vr, vc] = 1 + return outlines + + +def outlines_list(masks, multiprocessing_threshold=1000, multiprocessing=None): + """Get outlines of masks as a list to loop over for plotting. + + Args: + masks (ndarray): Array of masks. + multiprocessing_threshold (int, optional): Threshold for enabling multiprocessing. Defaults to 1000. + multiprocessing (bool, optional): Flag to enable multiprocessing. Defaults to None. + + Returns: + list: List of outlines. + + Raises: + None + + Notes: + - This function is a wrapper for outlines_list_single and outlines_list_multi. 
+ - Multiprocessing is disabled for Windows. + """ + # default to use multiprocessing if not few_masks, but allow user to override + if multiprocessing is None: + few_masks = np.max(masks) < multiprocessing_threshold + multiprocessing = not few_masks + + # disable multiprocessing for Windows + if os.name == "nt": + if multiprocessing: + logging.getLogger(__name__).warning( + "Multiprocessing is disabled for Windows") + multiprocessing = False + + if multiprocessing: + return outlines_list_multi(masks) + else: + return outlines_list_single(masks) + + +def outlines_list_single(masks): + """Get outlines of masks as a list to loop over for plotting. + + Args: + masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + + Returns: + list: List of outlines as pixel coordinates. + + """ + outpix = [] + for n in np.unique(masks)[1:]: + mn = masks == n + if mn.sum() > 0: + contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL, + method=cv2.CHAIN_APPROX_NONE) + contours = contours[-2] + cmax = np.argmax([c.shape[0] for c in contours]) + pix = contours[cmax].astype(int).squeeze() + if len(pix) > 4: + outpix.append(pix) + else: + outpix.append(np.zeros((0, 2))) + return outpix + + +def outlines_list_multi(masks, num_processes=None): + """ + Get outlines of masks as a list to loop over for plotting. + + Args: + masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + + Returns: + list: List of outlines as pixel coordinates. + """ + if num_processes is None: + num_processes = cpu_count() + + unique_masks = np.unique(masks)[1:] + with Pool(processes=num_processes) as pool: + outpix = pool.map(get_outline_multi, [(masks, n) for n in unique_masks]) + return outpix + + +def get_outline_multi(args): + """Get the outline of a specific mask in a multi-mask image. + + Args: + args (tuple): A tuple containing the masks and the mask number. + + Returns: + numpy.ndarray: The outline of the specified mask as an array of coordinates. 
+ + """ + masks, n = args + mn = masks == n + if mn.sum() > 0: + contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL, + method=cv2.CHAIN_APPROX_NONE) + contours = contours[-2] + cmax = np.argmax([c.shape[0] for c in contours]) + pix = contours[cmax].astype(int).squeeze() + return pix if len(pix) > 4 else np.zeros((0, 2)) + return np.zeros((0, 2)) + + +def dilate_masks(masks, n_iter=5): + """Dilate masks by n_iter pixels. + + Args: + masks (ndarray): Array of masks. + n_iter (int, optional): Number of pixels to dilate the masks. Defaults to 5. + + Returns: + ndarray: Dilated masks. + """ + dilated_masks = masks.copy() + for n in range(n_iter): + # define the structuring element to use for dilation + kernel = np.ones((3, 3), "uint8") + # find the distance to each mask (distances are zero within masks) + dist_transform = cv2.distanceTransform((dilated_masks == 0).astype("uint8"), + cv2.DIST_L2, 5) + # dilate each mask and assign to it the pixels along the border of the mask + # (does not allow dilation into other masks since dist_transform is zero there) + for i in range(1, np.max(masks) + 1): + mask = (dilated_masks == i).astype("uint8") + dilated_mask = cv2.dilate(mask, kernel, iterations=1) + dilated_mask = np.logical_and(dist_transform < 2, dilated_mask) + dilated_masks[dilated_mask > 0] = i + return dilated_masks + + +def get_perimeter(points): + """ + Calculate the perimeter of a set of points. + + Parameters: + points (ndarray): An array of points with shape (npoints, ndim). + + Returns: + float: The perimeter of the points. + + """ + if points.shape[0] > 4: + points = np.append(points, points[:1], axis=0) + return ((np.diff(points, axis=0)**2).sum(axis=1)**0.5).sum() + else: + return 0 + + +def get_mask_compactness(masks): + """ + Calculate the compactness of masks. + + Parameters: + masks (ndarray): Binary masks representing objects. + + Returns: + ndarray: Array of compactness values for each mask. 
+ """ + perimeters = get_mask_perimeters(masks) + npoints = np.unique(masks, return_counts=True)[1][1:] + areas = npoints + compactness = 4 * np.pi * areas / perimeters**2 + compactness[perimeters == 0] = 0 + compactness[compactness > 1.0] = 1.0 + return compactness + + +def get_mask_perimeters(masks): + """ + Calculate the perimeters of the given masks. + + Parameters: + masks (numpy.ndarray): Binary masks representing objects. + + Returns: + numpy.ndarray: Array containing the perimeters of each mask. + """ + perimeters = np.zeros(masks.max()) + for n in range(masks.max()): + mn = masks == (n + 1) + if mn.sum() > 0: + contours = cv2.findContours(mn.astype(np.uint8), mode=cv2.RETR_EXTERNAL, + method=cv2.CHAIN_APPROX_NONE)[-2] + perimeters[n] = np.array( + [get_perimeter(c.astype(int).squeeze()) for c in contours]).sum() + + return perimeters + + +def circleMask(d0): + """ + Creates an array with indices which are the radius of that x,y point. + + Args: + d0 (tuple): Patch of (-d0, d0+1) over which radius is computed. + + Returns: + tuple: A tuple containing: + - rs (ndarray): Array of radii with shape (2*d0[0]+1, 2*d0[1]+1). + - dx (ndarray): Indices of the patch along the x-axis. + - dy (ndarray): Indices of the patch along the y-axis. + """ + dx = np.tile(np.arange(-d0[1], d0[1] + 1), (2 * d0[0] + 1, 1)) + dy = np.tile(np.arange(-d0[0], d0[0] + 1), (2 * d0[1] + 1, 1)) + dy = dy.transpose() + + rs = (dy**2 + dx**2)**0.5 + return rs, dx, dy + + +def get_mask_stats(masks_true): + """ + Calculate various statistics for the given binary masks. + + Parameters: + masks_true (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + + Returns: + convexity (ndarray): Convexity values for each mask. + solidity (ndarray): Solidity values for each mask. + compactness (ndarray): Compactness values for each mask. 
+ """ + mask_perimeters = get_mask_perimeters(masks_true) + + # disk for compactness + rs, dy, dx = circleMask(np.array([100, 100])) + rsort = np.sort(rs.flatten()) + + # area for solidity + npoints = np.unique(masks_true, return_counts=True)[1][1:] + areas = npoints - mask_perimeters / 2 - 1 + + compactness = np.zeros(masks_true.max()) + convexity = np.zeros(masks_true.max()) + solidity = np.zeros(masks_true.max()) + convex_perimeters = np.zeros(masks_true.max()) + convex_areas = np.zeros(masks_true.max()) + for ic in range(masks_true.max()): + points = np.array(np.nonzero(masks_true == (ic + 1))).T + if len(points) > 15 and mask_perimeters[ic] > 0: + med = np.median(points, axis=0) + # compute compactness of ROI + r2 = ((points - med)**2).sum(axis=1)**0.5 + compactness[ic] = (rsort[:r2.size].mean() + 1e-10) / r2.mean() + try: + hull = ConvexHull(points) + convex_perimeters[ic] = hull.area + convex_areas[ic] = hull.volume + except: + convex_perimeters[ic] = 0 + + convexity[mask_perimeters > 0.0] = (convex_perimeters[mask_perimeters > 0.0] / + mask_perimeters[mask_perimeters > 0.0]) + solidity[convex_areas > 0.0] = (areas[convex_areas > 0.0] / + convex_areas[convex_areas > 0.0]) + convexity = np.clip(convexity, 0.0, 1.0) + solidity = np.clip(solidity, 0.0, 1.0) + compactness = np.clip(compactness, 0.0, 1.0) + return convexity, solidity, compactness + + +def get_masks_unet(output, cell_threshold=0, boundary_threshold=0): + """Create masks using cell probability and cell boundary. + + Args: + output (ndarray): The output array containing cell probability and cell boundary. + cell_threshold (float, optional): The threshold value for cell probability. Defaults to 0. + boundary_threshold (float, optional): The threshold value for cell boundary. Defaults to 0. + + Returns: + ndarray: The masks representing the segmented cells. 
+ + """ + cells = (output[..., 1] - output[..., 0]) > cell_threshold + selem = generate_binary_structure(cells.ndim, connectivity=1) + labels, nlabels = label(cells, selem) + + if output.shape[-1] > 2: + slices = find_objects(labels) + dists = 10000 * np.ones(labels.shape, np.float32) + mins = np.zeros(labels.shape, np.int32) + borders = np.logical_and(~(labels > 0), output[..., 2] > boundary_threshold) + pad = 10 + for i, slc in enumerate(slices): + if slc is not None: + slc_pad = tuple([ + slice(max(0, sli.start - pad), min(labels.shape[j], sli.stop + pad)) + for j, sli in enumerate(slc) + ]) + msk = (labels[slc_pad] == (i + 1)).astype(np.float32) + msk = 1 - gaussian_filter(msk, 5) + dists[slc_pad] = np.minimum(dists[slc_pad], msk) + mins[slc_pad][dists[slc_pad] == msk] = (i + 1) + labels[labels == 0] = borders[labels == 0] * mins[labels == 0] + + masks = labels + shape0 = masks.shape + _, masks = np.unique(masks, return_inverse=True) + masks = np.reshape(masks, shape0) + return masks + + +def stitch3D(masks, stitch_threshold=0.25): + """ + Stitch 2D masks into a 3D volume using a stitch_threshold on IOU. + + Args: + masks (list or ndarray): List of 2D masks. + stitch_threshold (float, optional): Threshold value for stitching. Defaults to 0.25. + + Returns: + list: List of stitched 3D masks. 
+ """ + mmax = masks[0].max() + empty = 0 + for i in trange(len(masks) - 1): + iou = metrics._intersection_over_union(masks[i + 1], masks[i])[1:, 1:] + if not iou.size and empty == 0: + masks[i + 1] = masks[i + 1] + mmax = masks[i + 1].max() + elif not iou.size and not empty == 0: + icount = masks[i + 1].max() + istitch = np.arange(mmax + 1, mmax + icount + 1, 1, masks.dtype) + mmax += icount + istitch = np.append(np.array(0), istitch) + masks[i + 1] = istitch[masks[i + 1]] + else: + iou[iou < stitch_threshold] = 0.0 + iou[iou < iou.max(axis=0)] = 0.0 + istitch = iou.argmax(axis=1) + 1 + ino = np.nonzero(iou.max(axis=1) == 0.0)[0] + istitch[ino] = np.arange(mmax + 1, mmax + len(ino) + 1, 1, masks.dtype) + mmax += len(ino) + istitch = np.append(np.array(0), istitch) + masks[i + 1] = istitch[masks[i + 1]] + empty = 1 + + return masks + + +def diameters(masks): + """ + Calculate the diameters of the objects in the given masks. + + Parameters: + masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + + Returns: + tuple: A tuple containing the median diameter and an array of diameters for each object. + + Examples: + >>> masks = np.array([[0, 1, 1], [1, 0, 0], [1, 1, 0]]) + >>> diameters(masks) + (1.0, array([1.41421356, 1.0, 1.0])) + """ + uniq, counts = fastremap.unique(masks.astype("int32"), return_counts=True) + counts = counts[1:] + md = np.median(counts**0.5) + if np.isnan(md): + md = 0 + md /= (np.pi**0.5) / 2 + return md, counts**0.5 + + +def radius_distribution(masks, bins): + """ + Calculate the radius distribution of masks. + + Args: + masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + bins (int): Number of bins for the histogram. + + Returns: + A tuple containing a normalized histogram of radii, median radius, array of radii. 
+ + """ + unique, counts = np.unique(masks, return_counts=True) + counts = counts[unique != 0] + nb, _ = np.histogram((counts**0.5) * 0.5, bins) + nb = nb.astype(np.float32) + if nb.sum() > 0: + nb = nb / nb.sum() + md = np.median(counts**0.5) * 0.5 + if np.isnan(md): + md = 0 + md /= (np.pi**0.5) / 2 + return nb, md, (counts**0.5) / 2 + + +def size_distribution(masks): + """ + Calculates the size distribution of masks. + + Args: + masks (ndarray): masks (0=no cells, 1=first cell, 2=second cell,...) + + Returns: + float: The ratio of the 25th percentile of mask sizes to the 75th percentile of mask sizes. + """ + counts = np.unique(masks, return_counts=True)[1][1:] + return np.percentile(counts, 25) / np.percentile(counts, 75) + + +def fill_holes_and_remove_small_masks(masks, min_size=15): + """ Fills holes in masks (2D/3D) and discards masks smaller than min_size. + + This function fills holes in each mask using fill_voids.fill. + It also removes masks that are smaller than the specified min_size. + + Parameters: + masks (ndarray): Int, 2D or 3D array of labelled masks. + 0 represents no mask, while positive integers represent mask labels. + The size can be [Ly x Lx] or [Lz x Ly x Lx]. + min_size (int, optional): Minimum number of pixels per mask. + Masks smaller than min_size will be removed. + Set to -1 to turn off this functionality. Default is 15. + + Returns: + ndarray: Int, 2D or 3D array of masks with holes filled and small masks removed. + 0 represents no mask, while positive integers represent mask labels. + The size is [Ly x Lx] or [Lz x Ly x Lx]. 
+ """ + + if masks.ndim > 3 or masks.ndim < 2: + raise ValueError("masks_to_outlines takes 2D or 3D array, not %dD array" % + masks.ndim) + + # Filter small masks + if min_size > 0: + counts = fastremap.unique(masks, return_counts=True)[1][1:] + masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1) + fastremap.renumber(masks, in_place=True) + + slices = find_objects(masks) + j = 0 + for i, slc in enumerate(slices): + if slc is not None: + msk = masks[slc] == (i + 1) + msk = fill_voids.fill(msk) + masks[slc][msk] = (j + 1) + j += 1 + + if min_size > 0: + counts = fastremap.unique(masks, return_counts=True)[1][1:] + masks = fastremap.mask(masks, np.nonzero(counts < min_size)[0] + 1) + fastremap.renumber(masks, in_place=True) + + return masks diff --git a/models/seg_post_model/cellpose/version.py b/models/seg_post_model/cellpose/version.py new file mode 100644 index 0000000000000000000000000000000000000000..97fd6b5fed7146232a9767c36f49519452b398f8 --- /dev/null +++ b/models/seg_post_model/cellpose/version.py @@ -0,0 +1,18 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu. +""" +from importlib.metadata import PackageNotFoundError, version +import sys +from platform import python_version +import torch + +try: + version = version("cellpose") +except PackageNotFoundError: + version = "unknown" + +version_str = f""" +cellpose version: \t{version} +platform: \t{sys.platform} +python version: \t{python_version()} +torch version: \t{torch.__version__}""" diff --git a/models/seg_post_model/cellpose/vit_sam.py b/models/seg_post_model/cellpose/vit_sam.py new file mode 100644 index 0000000000000000000000000000000000000000..bb7867a9552118de93db8a319e51f66fa70c3f87 --- /dev/null +++ b/models/seg_post_model/cellpose/vit_sam.py @@ -0,0 +1,195 @@ +""" +Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu. 
+""" + +import torch +from segment_anything import sam_model_registry +torch.backends.cuda.matmul.allow_tf32 = True +from torch import nn +import torch.nn.functional as F + +class Transformer(nn.Module): + def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4, + checkpoint=None, dtype=torch.float32): + super(Transformer, self).__init__() + """ + print(self.encoder.patch_embed) + PatchEmbed( + (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16)) + ) + print(self.encoder.neck) + Sequential( + (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False) + (1): LayerNorm2d() + (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False) + (3): LayerNorm2d() + ) + """ + # instantiate the vit model, default to not loading SAM + # checkpoint = sam_vit_l_0b3195.pth is standard pretrained SAM + self.encoder = sam_model_registry[backbone](checkpoint).image_encoder + w = self.encoder.patch_embed.proj.weight.detach() + nchan = w.shape[0] + + # change token size to ps x ps + self.ps = ps + self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps) + self.encoder.patch_embed.proj.weight.data = w[:,:,::16//ps,::16//ps] + + # adjust position embeddings for new bsize and new token size + ds = (1024 // 16) // (bsize // ps) + self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:,::ds,::ds], requires_grad=True) + + # readout weights for nout output channels + # if nout is changed, weights will not load correctly from pretrained Cellpose-SAM + self.nout = nout + self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1) + + # W2 reshapes token space to pixel space, not trainable + self.W2 = nn.Parameter(torch.eye(self.nout * ps**2).reshape(self.nout*ps**2, self.nout, ps, ps), + requires_grad=False) + + # fraction of layers to drop at random during training + self.rdrop = rdrop + + # average diameter of ROIs from training images from fine-tuning + self.diam_labels = 
    def forward(self, x, feat=None):
        """Encode an image batch with the ViT and read out nout full-resolution channels.

        Args:
            x (torch.Tensor): input image batch; spatial size must match the
                position-embedding grid configured in __init__ (bsize) -- TODO confirm.
            feat (torch.Tensor, optional): optional second image whose patch
                embedding multiplicatively modulates x
                (x <- x + 0.5 * x * embed(feat)). Defaults to None.

        Returns:
            tuple: (x1, noise) where x1 is the readout map and the second
                element is random noise kept only for backwards compatibility
                with callers expecting a 256-dim style vector -- its values
                are NOT deterministic; do not rely on them.
        """
        # same progression as SAM until readout
        x = self.encoder.patch_embed(x)
        if feat is not None:
            feat = self.encoder.patch_embed(feat)
            x = x + x * feat * 0.5

        if self.encoder.pos_embed is not None:
            x = x + self.encoder.pos_embed

        if self.training and self.rdrop > 0:
            # stochastic layer drop: drop probability rises linearly across
            # the block stack from 0 up to self.rdrop
            nlay = len(self.encoder.blocks)
            rdrop = (torch.rand((len(x), nlay), device=x.device) <
                     torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
            for i, blk in enumerate(self.encoder.blocks):
                # mask == 1 keeps x unchanged (block dropped); mask == 0 applies blk
                mask = rdrop[:,i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                x = x * mask + blk(x) * (1-mask)
        else:
            for blk in self.encoder.blocks:
                x = blk(x)

        # SAM neck expects channels-first: (B, H', W', C) -> (B, C, H', W')
        x = self.encoder.neck(x.permute(0, 3, 1, 2))

        # readout is changed here: 1x1 conv to nout*ps^2 channels, then the
        # fixed (requires_grad=False) W2 transposed conv rearranges tokens to pixels
        x1 = self.out(x)
        x1 = F.conv_transpose2d(x1, self.W2, stride = self.ps, padding = 0)

        # maintain the second output of feature size 256 for backwards compatibility

        return x1, torch.randn((x.shape[0], 256), device=x.device)
class CPnetBioImageIO(Transformer):
    """
    A subclass of the CP-SAM model compatible with the BioImage.IO Spec.

    This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
    allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
    """

    def forward(self, x):
        """
        Perform a forward pass of the model and return unpacked tensors.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: the output tensor, the style tensor, plus any extra tensors
            a parent forward may return (Transformer.forward returns two).
        """
        # FIX: Transformer.forward returns a 2-tuple (output, style); the
        # previous 3-way unpack raised ValueError. Star-unpack handles both
        # the 2-tuple here and CPnet-style (output, style, downsampled...).
        output_tensor, style_tensor, *extra_tensors = super().forward(x)
        return output_tensor, style_tensor, *extra_tensors

    def load_model(self, filename, device=None):
        """
        Load the model weights from a file.

        Args:
            filename (str): The path to the file where the model is saved.
            device (torch.device, optional): The device to load the model on. Defaults to None.
        """
        if (device is not None) and (device.type != "cpu"):
            state_dict = torch.load(filename, map_location=device, weights_only=True)
        else:
            # FIX: pass nout by keyword -- the first positional parameter of
            # Transformer.__init__ is `backbone`, so __init__(self.nout)
            # looked up an int in sam_model_registry.
            self.__init__(nout=self.nout)
            state_dict = torch.load(filename, map_location=torch.device("cpu"),
                                    weights_only=True)

        self.load_state_dict(state_dict)

    def load_state_dict(self, state_dict):
        """
        Load the state dictionary into the model.

        This method overrides the default `load_state_dict` to handle Cellpose's custom
        loading mechanism and ensures compatibility with BioImage.IO Core.

        Args:
            state_dict (Mapping[str, Any]): A state dictionary to load into the model
        """
        # "output.2.weight" is the CPnet readout key; guard against state
        # dicts (e.g. Transformer checkpoints) that do not contain it,
        # which previously raised KeyError.
        if ("output.2.weight" in state_dict and
                state_dict["output.2.weight"].shape[0] != self.nout):
            # output-channel mismatch: copy everything except the readout layer
            own_state = self.state_dict()  # hoisted: one call, not one per key
            for name in own_state:
                if "output" not in name:
                    own_state[name].copy_(state_dict[name])
        else:
            super().load_state_dict(dict(state_dict), strict=False)
self.encoder.patch_embed.proj.weight.data = w[:,:,::16//ps,::16//ps] + + # adjust position embeddings for new bsize and new token size + ds = (1024 // 16) // (bsize // ps) + self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:,::ds,::ds], requires_grad=True) + + # readout weights for nout output channels + # if nout is changed, weights will not load correctly from pretrained Cellpose-SAM + self.nout = nout + self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1) + + # W2 reshapes token space to pixel space, not trainable + self.W2 = nn.Parameter(torch.eye(self.nout * ps**2).reshape(self.nout*ps**2, self.nout, ps, ps), + requires_grad=False) + + # fraction of layers to drop at random during training + self.rdrop = rdrop + + # average diameter of ROIs from training images from fine-tuning + self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False) + # average diameter of ROIs during main training + self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False) + + # set attention to global in every layer + for blk in self.encoder.blocks: + blk.window_size = 0 + + self.dtype = dtype + + def forward(self, x, feat=None): + # same progression as SAM until readout + x = self.encoder.patch_embed(x) + if feat is not None: + feat = self.encoder.patch_embed(feat) + x = x + x * feat * 0.5 + + if self.encoder.pos_embed is not None: + x = x + self.encoder.pos_embed + + if self.training and self.rdrop > 0: + nlay = len(self.encoder.blocks) + rdrop = (torch.rand((len(x), nlay), device=x.device) < + torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype) + for i, blk in enumerate(self.encoder.blocks): + mask = rdrop[:,i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) + x = x * mask + blk(x) * (1-mask) + else: + for blk in self.encoder.blocks: + x = blk(x) + + x = self.encoder.neck(x.permute(0, 3, 1, 2)) + + # readout is changed here + x1 = self.out(x) + x1 = F.conv_transpose2d(x1, self.W2, stride = self.ps, padding = 0) + + # maintain 
class CPnetBioImageIO(Transformer):
    """
    A subclass of the CP-SAM model compatible with the BioImage.IO Spec.

    This subclass addresses the limitation of CPnet's incompatibility with the BioImage.IO Spec,
    allowing the CPnet model to use the weights uploaded to the BioImage.IO Model Zoo.
    """

    def forward(self, x):
        """
        Perform a forward pass of the model and return unpacked tensors.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: the output tensor, the style tensor, plus any extra tensors
            a parent forward may return (Transformer.forward returns two).
        """
        # FIX: Transformer.forward returns a 2-tuple (output, style); the
        # previous 3-way unpack raised ValueError. Star-unpack handles both
        # the 2-tuple here and CPnet-style (output, style, downsampled...).
        output_tensor, style_tensor, *extra_tensors = super().forward(x)
        return output_tensor, style_tensor, *extra_tensors

    def load_model(self, filename, device=None):
        """
        Load the model weights from a file.

        Args:
            filename (str): The path to the file where the model is saved.
            device (torch.device, optional): The device to load the model on. Defaults to None.
        """
        if (device is not None) and (device.type != "cpu"):
            state_dict = torch.load(filename, map_location=device, weights_only=True)
        else:
            # FIX: pass nout by keyword -- the first positional parameter of
            # Transformer.__init__ is `backbone`, so __init__(self.nout)
            # looked up an int in sam_model_registry.
            self.__init__(nout=self.nout)
            state_dict = torch.load(filename, map_location=torch.device("cpu"),
                                    weights_only=True)

        self.load_state_dict(state_dict)

    def load_state_dict(self, state_dict):
        """
        Load the state dictionary into the model.

        This method overrides the default `load_state_dict` to handle Cellpose's custom
        loading mechanism and ensures compatibility with BioImage.IO Core.

        Args:
            state_dict (Mapping[str, Any]): A state dictionary to load into the model
        """
        # "output.2.weight" is the CPnet readout key; guard against state
        # dicts (e.g. Transformer checkpoints) that do not contain it,
        # which previously raised KeyError.
        if ("output.2.weight" in state_dict and
                state_dict["output.2.weight"].shape[0] != self.nout):
            # output-channel mismatch: copy everything except the readout layer
            own_state = self.state_dict()  # hoisted: one call, not one per key
            for name in own_state:
                if "output" not in name:
                    own_state[name].copy_(state_dict[name])
        else:
            super().load_state_dict(dict(state_dict), strict=False)
b/models/tra_post_model/trackastra/data/_check_ctc.py new file mode 100644 index 0000000000000000000000000000000000000000..078efd78ddf465a2bc1d819193f1ae3d48593fa4 --- /dev/null +++ b/models/tra_post_model/trackastra/data/_check_ctc.py @@ -0,0 +1,114 @@ +import logging + +import numpy as np +import pandas as pd +from skimage.measure import label, regionprops_table + +logger = logging.getLogger(__name__) + + +# from https://github.com/Janelia-Trackathon-2023/traccuracy/blob/main/src/traccuracy/loaders/_ctc.py +def _check_ctc(tracks: pd.DataFrame, detections: pd.DataFrame, masks: np.ndarray): + """Sanity checks for valid CTC format. + + Hard checks (throws exception): + - Tracklet IDs in tracks file must be unique and positive + - Parent tracklet IDs must exist in the tracks file + - Intertracklet edges must be directed forward in time. + - In each time point, the set of segmentation IDs present in the detections must equal the set + of tracklet IDs in the tracks file that overlap this time point. + + Soft checks (prints warning): + - No duplicate tracklet IDs (non-connected pixels with same ID) in a single timepoint. + + Args: + tracks (pd.DataFrame): Tracks in CTC format with columns Cell_ID, Start, End, Parent_ID. + detections (pd.DataFrame): Detections extracted from masks, containing columns + segmentation_id, t. + masks (np.ndarray): Set of masks with time in the first axis. + + Raises: + ValueError: If any of the hard checks fail. 
+ """ + logger.debug("Running CTC format checks") + tracks = tracks.copy() + tracks.columns = ["Cell_ID", "Start", "End", "Parent_ID"] + if tracks["Cell_ID"].min() < 1: + raise ValueError("Cell_IDs in tracks file must be positive integers.") + if len(tracks["Cell_ID"]) < len(tracks["Cell_ID"].unique()): + raise ValueError("Cell_IDs in tracks file must be unique integers.") + + for _, row in tracks.iterrows(): + if row["Parent_ID"] != 0: + if row["Parent_ID"] not in tracks["Cell_ID"].values: + raise ValueError( + f"Parent_ID {row['Parent_ID']} is not present in tracks." + ) + parent_end = tracks[tracks["Cell_ID"] == row["Parent_ID"]]["End"].iloc[0] + if parent_end >= row["Start"]: + raise ValueError( + f"Invalid tracklet connection: Daughter tracklet with ID {row['Cell_ID']} " + f"starts at t={row['Start']}, " + f"but parent tracklet with ID {row['Parent_ID']} only ends at t={parent_end}." + ) + + for t in range(tracks["Start"].min(), tracks["End"].max()): + track_ids = set( + tracks[(tracks["Start"] <= t) & (tracks["End"] >= t)]["Cell_ID"] + ) + det_ids = set(detections[(detections["t"] == t)]["segmentation_id"]) + if not track_ids.issubset(det_ids): + raise ValueError(f"Missing IDs in masks at t={t}: {track_ids - det_ids}") + if not det_ids.issubset(track_ids): + raise ValueError( + f"IDs {det_ids - track_ids} at t={t} not represented in tracks file." + ) + + for t, frame in enumerate(masks): + _, n_components = label(frame, return_num=True) + n_labels = len(detections[detections["t"] == t]) + if n_labels < n_components: + logger.warning(f"{n_components - n_labels} non-connected masks at t={t}.") + + +def _get_node_attributes(masks): + """Calculates x,y,z,t,label for each detection in a movie. + + Args: + masks (np.ndarray): Set of masks with time in the first axis + + Returns: + pd.DataFrame: Dataframe with one detection per row. 
def _detections_from_image(stack, idx):
    """Return the unique track label, centroid and time for each track vertex.

    Args:
        stack (np.ndarray): Stack of masks
        idx (int): Index of the image to calculate the centroids and track labels

    Returns:
        pd.DataFrame: The dataframe of track data for one time step (specified by idx)
    """
    frame = np.asarray(stack[idx, ...])
    table = regionprops_table(frame, properties=("label", "centroid"))
    # attach the timepoint to every detection found in this frame
    table["t"] = np.full(table["label"].shape, idx)
    return pd.DataFrame(table)
+ + Args: + gt_frame (np.ndarray): ground truth segmentation for a single frame + res_frame (np.ndarray): result segmentation for a given frame + + Returns: + overlapping_gt_labels: List[int], labels of gt boxes that overlap with res boxes + overlapping_res_labels: List[int], labels of res boxes that overlap with gt boxes + intersections_over_gt: List[float], list of (intersection gt vs res) / (gt area) + """ + gt_frame = gt_frame.astype(np.uint16, copy=False) + res_frame = res_frame.astype(np.uint16, copy=False) + gt_props = regionprops(gt_frame) + gt_boxes = [np.array(gt_prop.bbox) for gt_prop in gt_props] + gt_boxes = np.array(gt_boxes).astype(np.float64) + gt_box_labels = np.asarray( + [int(gt_prop.label) for gt_prop in gt_props], dtype=np.uint16 + ) + + res_props = regionprops(res_frame) + res_boxes = [np.array(res_prop.bbox) for res_prop in res_props] + res_boxes = np.array(res_boxes).astype(np.float64) + res_box_labels = np.asarray( + [int(res_prop.label) for res_prop in res_props], dtype=np.uint16 + ) + if len(gt_props) == 0 or len(res_props) == 0: + return [], [], [] + + if gt_frame.ndim == 3: + overlaps = compute_overlap_3D(gt_boxes, res_boxes) + else: + overlaps = compute_overlap( + gt_boxes, res_boxes + ) # has the form [gt_bbox, res_bbox] + + # Find the bboxes that have overlap at all (ind_ corresponds to box number - starting at 0) + ind_gt, ind_res = np.nonzero(overlaps) + ind_gt = np.asarray(ind_gt, dtype=np.uint16) + ind_res = np.asarray(ind_res, dtype=np.uint16) + overlapping_gt_labels = gt_box_labels[ind_gt] + overlapping_res_labels = res_box_labels[ind_res] + + intersections_over_gt = [] + for i, j in zip(ind_gt, ind_res): + sslice = _union_slice(gt_props[i].slice, res_props[j].slice) + gt_mask = gt_frame[sslice] == gt_box_labels[i] + res_mask = res_frame[sslice] == res_box_labels[j] + area_inter = np.count_nonzero(np.logical_and(gt_mask, res_mask)) + area_gt = np.count_nonzero(gt_mask) + intersections_over_gt.append(area_inter / area_gt) + + 
def compute_overlap(boxes: np.ndarray, query_boxes: np.ndarray) -> np.ndarray:
    """Compute pairwise IoU between two sets of 2D bounding boxes.

    Boxes are (min_row, min_col, max_row, max_col); widths/heights use a
    "+ 1" inclusive-pixel convention. NOTE(review): callers pass skimage
    regionprops bboxes, whose stop edge is exclusive, so the +1 slightly
    pads each box -- presumably intentional for tolerant candidate matching;
    confirm against get_labels_with_overlap.

    NOTE: this function is re-wrapped with numba.njit at the bottom of the
    module using an explicit signature and `locals` typing tied to these
    exact variable names and the plain-loop structure -- keep them intact.

    Args:
        boxes: (N, 4) ndarray of float
        query_boxes: (K, 4) ndarray of float.

    Returns:
        overlaps: (N, K) ndarray of overlap between boxes and query_boxes
    """
    N = boxes.shape[0]
    K = query_boxes.shape[0]
    overlaps = np.zeros((N, K), dtype=np.float64)
    for k in range(K):
        # area of the query box (inclusive-coordinate convention)
        box_area = (query_boxes[k, 2] - query_boxes[k, 0] + 1) * (
            query_boxes[k, 3] - query_boxes[k, 1] + 1
        )
        for n in range(N):
            # intersection extent along rows; <= 0 means disjoint on this axis
            iw = (
                min(boxes[n, 2], query_boxes[k, 2])
                - max(boxes[n, 0], query_boxes[k, 0])
                + 1
            )
            if iw > 0:
                # intersection extent along columns
                ih = (
                    min(boxes[n, 3], query_boxes[k, 3])
                    - max(boxes[n, 1], query_boxes[k, 1])
                    + 1
                )
                if ih > 0:
                    # union area = area(box n) + area(query k) - intersection
                    ua = np.float64(
                        (boxes[n, 2] - boxes[n, 0] + 1)
                        * (boxes[n, 3] - boxes[n, 1] + 1)
                        + box_area
                        - iw * ih
                    )
                    overlaps[n, k] = iw * ih / ua
    return overlaps
+ + Returns: + overlaps: (N, K) ndarray of overlap between boxes and query_boxes + """ + N = boxes.shape[0] + K = query_boxes.shape[0] + overlaps = np.zeros((N, K), dtype=np.float64) + for k in range(K): + box_volume = ( + (query_boxes[k, 3] - query_boxes[k, 0] + 1) + * (query_boxes[k, 4] - query_boxes[k, 1] + 1) + * (query_boxes[k, 5] - query_boxes[k, 2] + 1) + ) + for n in range(N): + id_ = ( + min(boxes[n, 3], query_boxes[k, 3]) + - max(boxes[n, 0], query_boxes[k, 0]) + + 1 + ) + if id_ > 0: + iw = ( + min(boxes[n, 4], query_boxes[k, 4]) + - max(boxes[n, 1], query_boxes[k, 1]) + + 1 + ) + if iw > 0: + ih = ( + min(boxes[n, 5], query_boxes[k, 5]) + - max(boxes[n, 2], query_boxes[k, 2]) + + 1 + ) + if ih > 0: + ua = np.float64( + (boxes[n, 3] - boxes[n, 0] + 1) + * (boxes[n, 4] - boxes[n, 1] + 1) + * (boxes[n, 5] - boxes[n, 2] + 1) + + box_volume + - iw * ih * id_ + ) + overlaps[n, k] = iw * ih * id_ / ua + return overlaps + + +try: + import numba +except ImportError: + import os + import warnings + + if not os.getenv("NO_JIT_WARNING", False): + warnings.warn( + "Numba not installed, falling back to slower numpy implementation. " + "Install numba for a significant speedup. 
Set the environment " + "variable NO_JIT_WARNING=1 to disable this warning.", + stacklevel=2, + ) +else: + # compute_overlap 2d and 3d have the same signature + signature = [ + "f8[:,::1](f8[:,::1], f8[:,::1])", + numba.types.Array(numba.float64, 2, "C", readonly=True)( + numba.types.Array(numba.float64, 2, "C", readonly=True), + numba.types.Array(numba.float64, 2, "C", readonly=True), + ), + ] + + # variables that appear in the body of each function + common_locals = { + "N": numba.uint64, + "K": numba.uint64, + "overlaps": numba.types.Array(numba.float64, 2, "C"), + "iw": numba.float64, + "ih": numba.float64, + "ua": numba.float64, + "n": numba.uint64, + "k": numba.uint64, + } + + compute_overlap = numba.njit( + signature, + locals={**common_locals, "box_area": numba.float64}, + fastmath=True, + nogil=True, + boundscheck=False, + )(compute_overlap) + + compute_overlap_3D = numba.njit( + signature, + locals={**common_locals, "id_": numba.float64, "box_volume": numba.float64}, + fastmath=True, + nogil=True, + boundscheck=False, + )(compute_overlap_3D) diff --git a/models/tra_post_model/trackastra/data/augmentations.py b/models/tra_post_model/trackastra/data/augmentations.py new file mode 100644 index 0000000000000000000000000000000000000000..14d83e0ed8b9348153d482e35c294d9a4ffe1221 --- /dev/null +++ b/models/tra_post_model/trackastra/data/augmentations.py @@ -0,0 +1,557 @@ +"""#TODO: dont convert to numpy and back to torch.""" + +from collections.abc import Iterable, Sequence +from itertools import chain +from typing import Any + +import kornia.augmentation as K +import numpy as np +import torch +from kornia.augmentation import random_generator as rg +from kornia.augmentation.utils import _range_bound +from kornia.constants import DataKey, Resample +from typing import Optional, Tuple, Sequence, Dict, Union + +def default_augmenter(coords: np.ndarray): + # TODO parametrize magnitude of different augmentations + ndim = coords.shape[1] + + assert coords.ndim == 2 and 
ndim in (2, 3) + + # first remove offset + center = coords.mean(axis=0, keepdims=True) + + coords = coords - center + + # apply random flip + coords *= 2 * np.random.randint(0, 2, (1, ndim)) - 1 + + # apply rotation along the last two dimensions + phi = np.random.uniform(0, 2 * np.pi) + coords = _rotate(coords, phi, center=None) + + if ndim == 3: + # rotate along the first two dimensions too + phi2, phi3 = np.random.uniform(0, 2 * np.pi, 2) + coords = _rotate(coords, phi2, rot_axis=(0, 1), center=None) + coords = _rotate(coords, phi3, rot_axis=(0, 2), center=None) + + coords += center + + # translation + trans = 128 * np.random.uniform(-1, 1, (1, ndim)) + coords += trans + + # elastic + coords += 1.5 * np.random.normal(0, 1, coords.shape) + + return coords + + +def _rotate( + coords: np.ndarray, phi: float, rot_axis=(-2, -1), center: Optional[Tuple] = None +): + """Rotation along the last two dimensions of coords[..,:-2:].""" + ndim = coords.shape[1] + assert coords.ndim == 2 and ndim in (2, 3) + + if center is None: + center = (0,) * ndim + + assert len(center) == ndim + + center = np.asarray(center) + co, si = np.cos(phi), np.sin(phi) + Rot = np.eye(ndim) + Rot[np.ix_(rot_axis, rot_axis)] = np.array(((co, -si), (si, co))) + x = coords - center + x = x @ Rot.T + x += center + return x + + +def _filter_points( + points: np.ndarray, shape: tuple, origin: Optional[Tuple] = None +) -> np.ndarray: + """Returns indices of points that are inside the shape extent and given origin.""" + ndim = points.shape[-1] + if origin is None: + origin = (0,) * ndim + + idx = tuple( + np.logical_and(points[:, i] >= origin[i], points[:, i] < origin[i] + shape[i]) + for i in range(ndim) + ) + idx = np.where(np.all(idx, axis=0))[0] + return idx + + +class ConcatAffine(K.RandomAffine): + """Concatenate multiple affine transformations without intermediates.""" + + def __init__(self, affines: Sequence[K.RandomAffine]): + super().__init__(degrees=0) + self._affines = affines + if not 
# custom augmentations
class RandomIntensityScaleShift(K.IntensityAugmentationBase2D):
    r"""Apply a random scale and shift to the image intensity.

    Each image is transformed as ``img * scale + shift``, with both factors
    drawn uniformly from the given ranges.

    Args:
        p: probability of applying the transformation.
        scale: range to draw the multiplicative scale factor from.
        shift: range to draw the additive offset from.
        clip_output: if true clip output to [0, 1].
        same_on_batch: apply the same transformation across the batch.
        keepdim: whether to keep the output shape the same as input (True) or broadcast it
            to the batch form (False).
    Shape:
        - Input: :math:`(C, H, W)` or :math:`(B, C, H, W)`, Optional: :math:`(B, 3, 3)`
        - Output: :math:`(B, C, H, W)`
    """

    def __init__(
        self,
        scale: Tuple[float, float] = (0.5, 2.0),
        shift: Tuple[float, float] = (-0.1, 0.1),
        clip_output: bool = True,
        same_on_batch: bool = False,
        p: float = 1.0,
        keepdim: bool = False,
    ) -> None:
        super().__init__(p=p, same_on_batch=same_on_batch, keepdim=keepdim)
        # _range_bound validates the (min, max) ranges and converts them to tensors
        self.scale = _range_bound(
            scale, "scale", center=0, bounds=(-float("inf"), float("inf"))
        )
        self.shift = _range_bound(
            shift, "shift", center=0, bounds=(-float("inf"), float("inf"))
        )
        self._param_generator = rg.PlainUniformGenerator(
            (self.scale, "scale_factor", None, None),
            (self.shift, "shift_factor", None, None),
        )

        self.clip_output = clip_output

    def apply_transform(
        self,
        input: torch.Tensor,
        params: Dict[str, torch.Tensor],
        flags: Dict[str, Any],
        transform: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Apply ``input * scale + shift`` with per-sample factors from `params`."""
        scale_factor = params["scale_factor"].to(input)
        shift_factor = params["shift_factor"].to(input)
        # reshape to (B, 1, 1, 1) so the factors broadcast over C, H, W
        scale_factor = scale_factor.view(len(scale_factor), 1, 1, 1)
        # bugfix: use the shift tensor's own length here (was len(scale_factor),
        # which only worked because both happen to have batch length)
        shift_factor = shift_factor.view(len(shift_factor), 1, 1, 1)
        img_adjust = input * scale_factor + shift_factor
        if self.clip_output:
            img_adjust = img_adjust.clamp(min=0.0, max=1.0)
        return img_adjust
class BasicPipeline:
    """Applies a kornia augmentation pipeline jointly to img, mask, and points.

    Only supports 2D transformations for now (any 3D object will preserve its
    z coordinates/dimensions).
    """

    def __init__(self, augs: tuple, filter_points: bool = True):
        # Augment image, segmentation mask, and keypoints with the same transforms.
        self.data_keys = ("input", "mask", "keypoints")
        self.pipeline = K.AugmentationSequential(
            *augs,
            # disable align_corners to not trigger lots of warnings from kornia;
            # masks are resampled with NEAREST to keep labels intact
            extra_args={
                DataKey.MASK: {"resample": Resample.NEAREST, "align_corners": False}
            },
            data_keys=self.data_keys,
        )
        # if True, points falling outside the augmented image are dropped
        self.filter_points = filter_points

    def __call__(
        self,
        img: np.ndarray,
        mask: np.ndarray,
        points: np.ndarray,
        timepoints: np.ndarray,
    ):
        """Augment a (T, ...) image/mask stack together with per-point coordinates.

        Args:
            img: array of shape (T, (Z), Y, X).
            mask: label array, same shape as img.
            points: (N, ndim) coordinates in (y, x) / (z, y, x) order.
            timepoints: (N,) window-relative frame index of each point.

        Returns:
            ((img, mask, points), idx) where idx are the indices of the
            surviving (in-bounds) points into the input `points`.
        """
        ndim = img.ndim - 1
        assert (
            ndim in (2, 3)
            and points.ndim == 2
            and points.shape[-1] == ndim
            and timepoints.ndim == 1
            and img.shape == mask.shape
        )

        x = torch.from_numpy(img).float()
        y = torch.from_numpy(mask.astype(np.int64)).float()

        # if 2D add dummy channel; flip point order as kornia expects xy, not yx
        if ndim == 2:
            x = x.unsqueeze(1)
            y = y.unsqueeze(1)
            p = points[..., [1, 0]]
        # if 3D we use z as channel (i.e. fix augs across z)
        elif ndim == 3:
            p = points[..., [2, 1]]

        p = torch.from_numpy(p).unsqueeze(0).float()
        # add batch by duplicating to make kornia happy
        p = p.expand(len(x), -1, -1)
        # remember which timepoint (batch element) each point belongs to
        ts = torch.from_numpy(timepoints).long()
        n_points = p.shape[1]
        if n_points > 0:
            x, y, p = self.pipeline(x, y, p)
        else:
            # dummy keypoints so the pipeline call signature stays valid
            x, y = self.pipeline(x, y, torch.zeros((len(x), 1, 2)))[:2]

        # pick, for each point, its own timepoint's transformed copy (removes batch)
        p = p[ts, torch.arange(n_points)]
        # flip back to yx order
        p = p[..., [1, 0]]

        # remove the dummy channel again
        if ndim == 2:
            x = x.squeeze(1)
            y = y.squeeze(1)

        x = x.numpy()
        y = y.numpy().astype(np.uint16)
        p = p.numpy()
        # add back (untouched) z coordinates
        if ndim == 3:
            p = np.concatenate([points[..., 0:1], p], axis=-1)
        ts = ts.numpy()

        # remove points that were pushed outside of img/mask by the augs
        if self.filter_points:
            idx = _filter_points(p, shape=x.shape[-ndim:])
        else:
            idx = np.arange(len(p), dtype=int)

        p = p[idx]
        return (x, y, p), idx
class RandomCrop:
    """Random spatial crop of an image/mask stack plus point coordinates.

    The first (time) axis is never cropped. Crops may extend beyond the image
    bounds when `use_padding` is True; out-of-bounds areas are filled via
    `padding_mode`.
    """

    def __init__(
        self,
        crop_size: Optional[Union[int, Tuple[int, ...]]] = None,
        ndim: int = 2,
        ensure_inside_points: bool = False,
        use_padding: bool = True,
        padding_mode="constant",
    ) -> None:
        """crop_size: tuple of int
        can be tuple of length 1 (all dimensions)
        of length ndim (y,x,...)
        of length 2*ndim (y1,y2, x1,x2, ...) giving min/max bounds per axis.
        """
        if isinstance(crop_size, int):
            crop_size = (crop_size,) * 2 * ndim
        elif isinstance(crop_size, Iterable):
            pass
        else:
            raise ValueError(f"{crop_size} has to be int or tuple of int")

        # normalize to the (min, max) per-axis form of length 2*ndim
        if len(crop_size) == 1:
            crop_size = (crop_size[0],) * 2 * ndim
        elif len(crop_size) == ndim:
            crop_size = tuple(chain(*tuple((c, c) for c in crop_size)))
        elif len(crop_size) == 2 * ndim:
            pass
        else:
            raise ValueError(f"crop_size has to be of length 1, {ndim}, or {2 * ndim}")

        crop_size = np.array(crop_size)
        self._ndim = ndim
        # per-axis (lower, upper) bounds for the sampled crop size
        self._crop_bounds = crop_size[::2], crop_size[1::2]
        self._use_padding = use_padding
        self._ensure_inside_points = ensure_inside_points
        self._rng = np.random.RandomState()
        self._padding_mode = padding_mode

    def crop_img(self, img: np.ndarray, corner: np.ndarray, crop_size: np.ndarray):
        """Crop `img` at `corner` (may be negative), padding where needed."""
        if not img.ndim == self._ndim + 1:
            raise ValueError(
                f"img has to be 1 (time) + {self._ndim} spatial dimensions"
            )

        pad_left = np.maximum(0, -corner)
        pad_right = np.maximum(
            0, corner + crop_size - np.array(img.shape[-self._ndim :])
        )

        img = np.pad(
            img,
            ((0, 0), *tuple(np.stack((pad_left, pad_right)).T)),
            mode=self._padding_mode,
        )
        slices = (
            slice(None),
            *tuple(slice(c, c + s) for c, s in zip(corner + pad_left, crop_size)),
        )
        return img[slices]

    def crop_points(
        self, points: np.ndarray, corner: np.ndarray, crop_size: np.ndarray
    ):
        """Keep points inside the crop and shift them into crop coordinates."""
        idx = _filter_points(points, shape=crop_size, origin=corner)
        return points[idx] - corner, idx

    def __call__(self, img: np.ndarray, mask: np.ndarray, points: np.ndarray):
        """Sample one crop and apply it to img, mask, and points.

        Returns:
            ((img, mask, points), idx) with idx the surviving point indices.
        """
        assert (
            img.ndim == self._ndim + 1
            and points.ndim == 2
            and points.shape[-1] == self._ndim
            and img.shape == mask.shape
        )

        points = points.astype(int)

        crop_size = self._rng.randint(self._crop_bounds[0], self._crop_bounds[1] + 1)

        if self._ensure_inside_points:
            if len(points) == 0:
                print("No points given, cannot ensure inside points")
                return (img, mask, points), np.zeros((0,), int)

            # sample a point and place the crop corner relative to it
            # bugfix: draw from self._rng (was np.random), so that seeding
            # the instance rng yields reproducible crops
            _idx = self._rng.randint(len(points))
            corner = (
                points[_idx]
                - crop_size
                + 1
                + self._rng.randint(crop_size // 4, 3 * crop_size // 4)
            )
        else:
            corner = self._rng.randint(
                0, np.maximum(1, np.array(img.shape[-self._ndim :]) - crop_size)
            )

        if not self._use_padding:
            # clamp the crop fully inside the image instead of padding
            corner = np.maximum(0, corner)
            crop_size = np.minimum(
                crop_size, np.array(img.shape[-self._ndim :]) - corner
            )

        img = self.crop_img(img, corner, crop_size)
        mask = self.crop_img(mask, corner, crop_size)
        points, idx = self.crop_points(points, corner, crop_size)

        return (img, mask, points), idx
class AugmentationPipeline(BasicPipeline):
    """transforms img, mask, and points."""

    def __init__(self, p=0.5, filter_points=True, level=1):
        # dispatch table: only the requested level's augmentations are built
        builders = {
            1: self._augs_level_1,
            2: self._augs_level_2,
            3: self._augs_level_3,
            4: self._augs_level_4,
        }
        if level not in builders:
            raise ValueError(f"level {level} not supported")
        super().__init__(builders[level](p), filter_points)

    @staticmethod
    def _augs_level_1(p):
        # Mild default: window-consistent flips/affine plus per-frame noise.
        return [
            # Augmentations for all images in a window
            K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
            K.RandomVerticalFlip(p=0.5, same_on_batch=True),
            K.RandomAffine(
                degrees=180,
                shear=(-10, 10, -10, 10),  # x_min, x_max, y_min, y_max
                translate=(0.05, 0.05),
                scale=(0.8, 1.2),
                p=p,
                same_on_batch=True,
            ),
            K.RandomBrightness(
                (0.5, 1.5), clip_output=False, p=p, same_on_batch=True
            ),
            K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
        ]

    @staticmethod
    def _augs_level_2(p):
        # Crafted for DeepCell crop size 256
        return [
            # Augmentations for all images in a window
            K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
            K.RandomVerticalFlip(p=0.5, same_on_batch=True),
            K.RandomAffine(
                degrees=180,
                shear=(-5, 5, -5, 5),  # x_min, x_max, y_min, y_max
                translate=(0.03, 0.03),
                scale=(0.8, 1.2),  # isotropic
                p=p,
                same_on_batch=True,
            ),
            # Anisotropic scaling
            K.RandomAffine(
                degrees=0,
                scale=(0.9, 1.1, 0.9, 1.1),  # x_min, x_max, y_min, y_max
                p=p,
                same_on_batch=True,
            ),
            # Independent augmentations for each image in window
            K.RandomAffine(
                degrees=3,
                shear=(-2, 2, -2, 2),  # x_min, x_max, y_min, y_max
                translate=(0.04, 0.04),
                p=p,
                same_on_batch=False,
            ),
            # Intensity-based augmentations
            K.RandomBrightness(
                (0.5, 1.5), clip_output=False, p=p, same_on_batch=True
            ),
            K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
        ]

    @staticmethod
    def _augs_level_3(p):
        # Crafted for DeepCell crop size 256
        return [
            # Augmentations for all images in a window
            K.RandomHorizontalFlip(p=0.5, same_on_batch=True),
            K.RandomVerticalFlip(p=0.5, same_on_batch=True),
            ConcatAffine([
                K.RandomAffine(
                    degrees=180,
                    shear=(-5, 5, -5, 5),  # x_min, x_max, y_min, y_max
                    translate=(0.03, 0.03),
                    scale=(0.8, 1.2),  # isotropic
                    p=p,
                    same_on_batch=True,
                ),
                # Anisotropic scaling
                K.RandomAffine(
                    degrees=0,
                    scale=(0.9, 1.1, 0.9, 1.1),  # x_min, x_max, y_min, y_max
                    p=p,
                    same_on_batch=True,
                ),
            ]),
            RandomTemporalAffine(
                degrees=10,
                translate=(0.05, 0.05),
                p=p,
            ),
            # Independent augmentations for each image in window
            K.RandomAffine(
                degrees=2,
                shear=(-2, 2, -2, 2),  # x_min, x_max, y_min, y_max
                translate=(0.01, 0.01),
                p=0.5 * p,
                same_on_batch=False,
            ),
            # Intensity-based augmentations
            RandomIntensityScaleShift(
                (0.5, 2.0), (-0.1, 0.1), clip_output=False, p=p, same_on_batch=True
            ),
            K.RandomGaussianNoise(mean=0.0, std=0.03, p=p, same_on_batch=False),
        ]

    @staticmethod
    def _augs_level_4(p):
        # debug
        return [
            K.RandomAffine(
                degrees=30,
                shear=(-0, 0, -0, 0),  # x_min, x_max, y_min, y_max
                translate=(0.0, 0.0),
                p=1,
                same_on_batch=True,
            ),
        ]
0000000000000000000000000000000000000000..fde0528b97e39b76c35cb35d3a54ac622acb53d6 --- /dev/null +++ b/models/tra_post_model/trackastra/data/data.py @@ -0,0 +1,1509 @@ +import logging +# from collections.abc import Sequence +from pathlib import Path +from timeit import default_timer +from typing import Literal + +import joblib +import lz4.frame +import networkx as nx +import numpy as np +import pandas as pd +import tifffile +import torch +from numba import njit +from scipy import ndimage as ndi +from scipy.spatial.distance import cdist +from skimage.measure import regionprops +from skimage.segmentation import relabel_sequential +from torch.utils.data import Dataset +from tqdm import tqdm + +from . import wrfeat +from ._check_ctc import _check_ctc, _get_node_attributes +from .augmentations import ( + AugmentationPipeline, + RandomCrop, + default_augmenter, +) +from .features import ( + _PROPERTIES, + extract_features_patch, + extract_features_regionprops, +) +from .matching import matching + +from typing import List, Optional, Union, Tuple, Sequence + +# from ..utils import blockwise_sum, normalize +from ..utils import blockwise_sum, normalize + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def _filter_track_df(df, start_frame, end_frame, downscale): + """Only keep tracklets that are present in the given time interval.""" + # only retain cells in interval + df = df[(df.t2 >= start_frame) & (df.t1 < end_frame)] + + # shift start and end of each cell + df.t1 = df.t1 - start_frame + df.t2 = df.t2 - start_frame + # set start/end to min/max + df.t1 = df.t1.clip(0, end_frame - start_frame - 1) + df.t2 = df.t2.clip(0, end_frame - start_frame - 1) + # set all parents to 0 that are not in the interval + df.loc[~df.parent.isin(df.label), "parent"] = 0 + + if downscale > 1: + if start_frame % downscale != 0: + raise ValueError("start_frame must be a multiple of downscale") + + logger.info(f"Temporal downscaling of tracklet links by {downscale}") + + 
# remove tracklets that have been fully deleted by temporal downsampling + + mask = ( + # (df["t2"] - df["t1"] < downscale - 1) + (df["t1"] % downscale != 0) + & (df["t2"] % downscale != 0) + & (df["t1"] // downscale == df["t2"] // downscale) + ) + logger.info( + f"Remove {mask.sum()} tracklets that are fully deleted by downsampling" + ) + logger.debug(f"Remove {df[mask]}") + + df = df[~mask] + # set parent to 0 if it has been deleted + df.loc[~df.parent.isin(df.label), "parent"] = 0 + + df["t2"] = (df["t2"] / float(downscale)).apply(np.floor).astype(int) + df["t1"] = (df["t1"] / float(downscale)).apply(np.ceil).astype(int) + + # Correct for edge case of single frame tracklet + assert np.all(df["t1"] == np.minimum(df["t1"], df["t2"])) + + return df + + +class _CompressedArray: + """a simple class to compress and decompress a numpy arrays using lz4.""" + + # dont compress float types + def __init__(self, data): + self._data = lz4.frame.compress(data) + self._dtype = data.dtype.type + self._shape = data.shape + + def decompress(self): + s = lz4.frame.decompress(self._data) + data = np.frombuffer(s, dtype=self._dtype).reshape(self._shape) + return data + + +def debug_function(f): + def wrapper(*args, **kwargs): + try: + batch = f(*args, **kwargs) + except Exception as e: + logger.error(f"Error in {f.__name__}: {e}") + return None + logger.info(f"XXXX {len(batch['coords'])}") + return batch + + return wrapper + + +class CTCData(Dataset): + def __init__( + self, + root: str = "", + ndim: int = 2, + use_gt: bool = True, + detection_folders: List[str] = ["TRA"], + window_size: int = 10, + max_tokens: Optional[int] = None, + slice_pct: tuple = (0.0, 1.0), + downscale_spatial: int = 1, + downscale_temporal: int = 1, + augment: int = 0, + features: Literal[ + "none", + "regionprops", + "regionprops2", + "patch", + "patch_regionprops", + "wrfeat", + ] = "wrfeat", + sanity_dist: bool = False, + crop_size: Optional[tuple] = None, + return_dense: bool = False, + compress: bool = 
False, + **kwargs, + ) -> None: + """_summary_. + + Args: + root (str): + Folder containing the CTC TRA folder. + ndim (int): + Number of dimensions of the data. Defaults to 2d + (if ndim=3 and data is two dimensional, it will be cast to 3D) + detection_folders: + List of relative paths to folder with detections. + Defaults to ["TRA"], which uses the ground truth detections. + window_size (int): + Window size for transformer. + slice_pct (tuple): + Slice the dataset by percentages (from, to). + augment (int): + if 0, no data augmentation. if > 0, defines level of data augmentation. + features (str): + Types of features to use. + sanity_dist (bool): + Use euclidian distance instead of the association matrix as a target. + crop_size (tuple): + Size of the crops to use for augmentation. If None, no cropping is used. + return_dense (bool): + Return dense masks and images in the data samples. + compress (bool): + Compress elements/remove img if not needed to save memory for large datasets + """ + super().__init__() + + self.root = Path(root) + self.name = self.root.name + self.use_gt = use_gt + self.slice_pct = slice_pct + if not 0 <= slice_pct[0] < slice_pct[1] <= 1: + raise ValueError(f"Invalid slice_pct {slice_pct}") + self.downscale_spatial = downscale_spatial + self.downscale_temporal = downscale_temporal + self.detection_folders = detection_folders + self.ndim = ndim + self.features = features + + if features not in ("none", "wrfeat") and features not in _PROPERTIES[ndim]: + raise ValueError( + f"'{features}' not one of the supported {ndim}D features" + f" {tuple(_PROPERTIES[ndim].keys())}" + ) + + logger.info(f"ROOT (config): \t{self.root}") + self.root, self.gt_tra_folder = self._guess_root_and_gt_tra_folder(self.root) + logger.info(f"ROOT (guessed): \t{self.root}") + logger.info(f"GT TRA (guessed):\t{self.gt_tra_folder}") + if self.use_gt: + self.gt_mask_folder = self._guess_mask_folder(self.root, self.gt_tra_folder) + else: + logger.info("Using dummy masks as 
GT") + self.gt_mask_folder = self._guess_det_folder( + self.root, self.detection_folders[0] + ) + logger.info(f"GT MASK (guessed):\t{self.gt_mask_folder}") + + # dont load image data if not needed + if features in ("none",): + self.img_folder = None + else: + self.img_folder = self._guess_img_folder(self.root) + logger.info(f"IMG (guessed):\t{self.img_folder}") + + self.feat_dim, self.augmenter, self.cropper = self._setup_features_augs( + ndim, features, augment, crop_size + ) + + if window_size <= 1: + raise ValueError("window must be >1") + self.window_size = window_size + self.max_tokens = max_tokens + + self.slice_pct = slice_pct + self.sanity_dist = sanity_dist + self.return_dense = return_dense + self.compress = compress + self.start_frame = 0 + self.end_frame = None + + start = default_timer() + + if self.features == "wrfeat": + self.windows = self._load_wrfeat() + else: + self.windows = self._load() + + self.n_divs = self._get_ndivs(self.windows) + + if len(self.windows) > 0: + self.ndim = self.windows[0]["coords"].shape[1] + self.n_objects = tuple(len(t["coords"]) for t in self.windows) + logger.info( + f"Found {np.sum(self.n_objects)} objects in {len(self.windows)} track" + f" windows from {self.root} ({default_timer() - start:.1f}s)\n" + ) + else: + self.n_objects = 0 + logger.warning(f"Could not load any tracks from {self.root}") + + if self.compress: + self._compress_data() + + # def from_ctc + + @classmethod + def from_arrays(cls, imgs: np.ndarray, masks: np.ndarray, train_args: dict): + self = cls(**train_args) + # for key, value in train_args.items(): + # setattr(self, key, value) + + # self.use_gt = use_gt + # self.slice_pct = slice_pct + # if not 0 <= slice_pct[0] < slice_pct[1] <= 1: + # raise ValueError(f"Invalid slice_pct {slice_pct}") + # self.downscale_spatial = downscale_spatial + # self.downscale_temporal = downscale_temporal + # self.detection_folders = detection_folders + # self.ndim = ndim + # self.features = features + + # if features 
    def _get_ndivs(self, windows):
        """Count divisions per window from its association matrix.

        For each window, the association matrix is summed blockwise over
        timepoints (presumably via `blockwise_sum` from ..utils — confirm its
        exact semantics there); a column whose per-block maximum equals 2 is
        counted as a division (one mother linked to two daughters).

        Returns:
            List with one division count (int) per window.
        """
        n_divs = []
        for w in tqdm(windows, desc="Counting divisions", leave=False):
            _n = (
                (
                    blockwise_sum(
                        torch.from_numpy(w["assoc_matrix"]).float(),
                        torch.from_numpy(w["timepoints"]).long(),
                    ).max(dim=0)[0]
                    == 2
                )
                .sum()
                .item()
            )
            n_divs.append(_n)
        return n_divs
    def _setup_features_augs(
        self, ndim: int, features: str, augment: int, crop_size: Tuple[int]
    ):
        """Set up feature dimensionality, augmenter, and cropper for this config.

        Returns:
            (feat_dim, augmenter, cropper) — feature vector length for the
            chosen `features` mode, an AugmentationPipeline (or None when
            augment == 0), and a RandomCrop (or None when crop_size is None).
        """
        # wrfeat mode has its own dedicated setup path
        if self.features == "wrfeat":
            return self._setup_features_augs_wrfeat(ndim, features, augment, crop_size)

        cropper = (
            RandomCrop(
                crop_size=crop_size,
                ndim=ndim,
                use_padding=False,
                ensure_inside_points=True,
            )
            if crop_size is not None
            else None
        )

        # Hack: "none" short-circuits with the plain coordinate augmenter
        if self.features == "none":
            return 0, default_augmenter, cropper

        # feature dimensionality per mode; NOTE(review): an unsupported
        # `features` key raises KeyError here rather than a ValueError
        if ndim == 2:
            augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
            feat_dim = {
                "none": 0,
                "regionprops": 7,
                "regionprops2": 6,
                "patch": 256,
                "patch_regionprops": 256 + 5,
            }[features]
        elif ndim == 3:
            augmenter = AugmentationPipeline(p=0.8, level=augment) if augment else None
            feat_dim = {
                "none": 0,
                "regionprops2": 11,
                "patch_regionprops": 256 + 8,
            }[features]

        return feat_dim, augmenter, cropper
    @classmethod
    def _guess_det_folder(cls, root: Path, suffix: str):
        """Locate a detection folder, tolerating CTC-style folder naming.

        Checks for the annoying CTC format with dataset numbering as part of
        folder names. Tries, in order: `root/suffix`, `root_suffix`, and
        `root_GT/suffix`, returning the first existing path.

        Returns:
            The first existing candidate Path, or None (with a warning) if
            none of the guesses exists on disk.
        """
        guesses = (
            (root / suffix),
            Path(f"{root}_{suffix}"),
            Path(f"{root}_GT") / suffix,
        )
        for path in guesses:
            if path.exists():
                return path

        logger.warning(f"Skipping non-existing detection folder {root / suffix}")
        return None
    def _load_tiffs(self, folder: Path, dtype=None):
        """Load a sorted tiff series from `folder` into a single stacked array.

        Applies the dataset's temporal slicing/downscaling while selecting
        files and spatial downscaling by striding after loading.

        Args:
            folder: directory containing per-frame `*.tif` files.
            dtype: numpy dtype the frames are cast to (passed to astype).

        Returns:
            Array of shape (T, (Z), Y, X).
        """
        assert isinstance(self.downscale_temporal, int)
        logger.debug(f"Loading tiffs from {folder} as {dtype}")
        logger.debug(
            f"Temporal downscaling of {folder.name} by {self.downscale_temporal}"
        )
        # temporal selection happens on the sorted file list itself
        x = np.stack([
            tifffile.imread(f).astype(dtype)
            for f in tqdm(
                sorted(folder.glob("*.tif"))[
                    self.start_frame : self.end_frame : self.downscale_temporal
                ],
                leave=False,
                desc=f"Loading [{self.start_frame}:{self.end_frame}]",
            )
        ])

        # T, (Z), Y, X
        assert isinstance(self.downscale_spatial, int)
        if self.downscale_spatial > 1 or self.downscale_temporal > 1:
            # TODO make safe for label arrays
            # (strided subsampling can split/drop thin labelled regions)
            logger.debug(
                f"Spatial downscaling of {folder.name} by {self.downscale_spatial}"
            )
            slices = (
                slice(None),
                *tuple(
                    slice(None, None, self.downscale_spatial) for _ in range(x.ndim - 1)
                ),
            )
            x = x[slices]

        logger.debug(f"Loaded array of shape {x.shape} from {folder}")
        return x
label id. + + Args: + masks (np.ndarray): T, (Z), H, W + + Returns: + labels: List of labels + ts: List of timepoints + coords: List of coordinates + """ + # Get coordinates, timepoints, and labels of detections + labels = [] + ts = [] + coords = [] + properties_by_time = dict() + assert len(self.imgs) == len(masks) + for _t, frame in tqdm( + enumerate(masks), + # total=len(detections), + leave=False, + desc="Loading masks and properties", + ): + regions = regionprops(frame) + t_labels = [] + t_ts = [] + t_coords = [] + for _r in regions: + t_labels.append(_r.label) + t_ts.append(_t) + centroid = np.array(_r.centroid).astype(int) + t_coords.append(centroid) + + properties_by_time[_t] = dict(coords=t_coords, labels=t_labels) + labels.extend(t_labels) + ts.extend(t_ts) + coords.extend(t_coords) + + labels = np.array(labels, dtype=int) + ts = np.array(ts, dtype=int) + coords = np.array(coords, dtype=int) + + return labels, ts, coords, properties_by_time + + def _load_tracklet_links(self, folder: Path) -> pd.DataFrame: + df = pd.read_csv( + folder / "man_track.txt", + delimiter=" ", + names=["label", "t1", "t2", "parent"], + dtype=int, + ) + n_dets = (df.t2 - df.t1 + 1).sum() + logger.debug(f"{folder} has {n_dets} detections") + + n_divs = (df[df.parent != 0]["parent"].value_counts() == 2).sum() + logger.debug(f"{folder} has {n_divs} divisions") + return df + + def _build_tracklets_without_gt(self, masks): + """Create a dataframe with tracklets from masks.""" + rows = [] + for t, m in enumerate(masks): + for c in np.unique(m[m > 0]): + rows.append([c, t, t, 0]) + df = pd.DataFrame(rows, columns=["label", "t1", "t2", "parent"]) + return df + + def _check_dimensions(self, x: np.ndarray): + if self.ndim == 2 and not x.ndim == 3: + raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data") + elif self.ndim == 3: + # if ndim=3 and data is two dimensional, it will be cast to 3D + if x.ndim == 3: + x = np.expand_dims(x, axis=1) + elif x.ndim == 4: + pass + else: + raise 
    def _load(self):
        """Load GT, images, and all detection folders; build sliding windows.

        Returns:
            List of window dicts (see `_build_windows`) accumulated over all
            `self.detection_folders`.
        """
        # Load ground truth
        logger.info("Loading ground truth")
        self.gt_masks, self.gt_track_df = self._load_gt()

        self.gt_masks = self._check_dimensions(self.gt_masks)

        # Load images (or a zero stack when features need no image data)
        if self.img_folder is None:
            self.imgs = np.zeros_like(self.gt_masks)
        else:
            logger.info("Loading images")
            imgs = self._load_tiffs(self.img_folder, dtype=np.float32)
            self.imgs = np.stack([
                normalize(_x) for _x in tqdm(imgs, desc="Normalizing", leave=False)
            ])
        self.imgs = self._check_dimensions(self.imgs)
        if self.compress:
            # prepare images to be compressed later (e.g. removing non masked parts for regionprops features)
            self.imgs = np.stack([
                _compress_img_mask_preproc(im, mask, self.features)
                for im, mask in zip(self.imgs, self.gt_masks)
            ])

        assert len(self.gt_masks) == len(self.imgs)

        # Load each of the detection folders and create data samples with a sliding window
        windows = []
        self.properties_by_time = dict()
        self.det_masks = dict()
        for _f in self.detection_folders:
            det_folder = self.root / _f

            if det_folder == self.gt_mask_folder:
                # detections ARE the GT masks: identity matching, no IO needed
                det_masks = self.gt_masks
                logger.info("DET MASK:\tUsing GT masks")
                (
                    det_labels,
                    det_ts,
                    det_coords,
                    det_properties_by_time,
                ) = self._masks2properties(det_masks)

                det_gt_matching = {
                    t: {_l: _l for _l in det_properties_by_time[t]["labels"]}
                    for t in range(len(det_masks))
                }
            else:
                det_folder = self._guess_det_folder(root=self.root, suffix=_f)
                if det_folder is None:
                    continue

                logger.info(f"DET MASK:\t{det_folder}")
                det_masks = self._load_tiffs(det_folder, dtype=np.int32)
                det_masks = self._correct_gt_with_st(
                    det_folder, det_masks, dtype=np.int32
                )
                det_masks = self._check_dimensions(det_masks)
                (
                    det_labels,
                    det_ts,
                    det_coords,
                    det_properties_by_time,
                ) = self._masks2properties(det_masks)

                # FIXME matching can be slow for big images
                # per-frame mapping: detection label -> matched GT label
                det_gt_matching = {
                    t: {
                        _d: _gt
                        for _gt, _d in matching(
                            self.gt_masks[t],
                            det_masks[t],
                            threshold=0.3,
                            max_distance=16,
                        )
                    }
                    for t in tqdm(range(len(det_masks)), leave=False, desc="Matching")
                }

            self.properties_by_time[_f] = det_properties_by_time
            self.det_masks[_f] = det_masks
            _w = self._build_windows(
                det_folder,
                det_masks,
                det_labels,
                det_ts,
                det_coords,
                det_gt_matching,
            )

            windows.extend(_w)

        return windows
    def __getitem__(self, n: int, return_dense=None):
        """Return the n-th training window as a dict of tensors.

        Applies optional cropping and augmentation, extracts features
        according to ``self.features``, optionally clips the token count,
        and converts everything to torch tensors.

        Args:
            n: Window index.
            return_dense: If True, also include dense ``img``/``mask``
                tensors in the result; defaults to ``self.return_dense``.
        """
        # if not set, use default
        if self.features == "wrfeat":
            return self._getitem_wrfeat(n, return_dense)

        if return_dense is None:
            return_dense = self.return_dense

        track = self.windows[n]
        coords = track["coords"]
        assoc_matrix = track["assoc_matrix"]
        labels = track["labels"]
        img = track["img"]
        mask = track["mask"]
        timepoints = track["timepoints"]
        min_time = track["t1"]

        # Decompress lazily-stored arrays on access.
        if isinstance(mask, _CompressedArray):
            mask = mask.decompress()
        if isinstance(img, _CompressedArray):
            img = img.decompress()
        if isinstance(assoc_matrix, _CompressedArray):
            assoc_matrix = assoc_matrix.decompress()

        # cropping
        if self.cropper is not None:
            (img2, mask2, coords2), idx = self.cropper(img, mask, coords)
            cropped_timepoints = timepoints[idx]

            # at least one detection in each timepoint to accept the crop
            if len(np.unique(cropped_timepoints)) == self.window_size:
                # at least two total detections to accept the crop
                # if len(idx) >= 2:
                img, mask, coords = img2, mask2, coords2
                labels = labels[idx]
                timepoints = timepoints[idx]
                assoc_matrix = assoc_matrix[idx][:, idx]
            else:
                logger.debug("disable cropping as no trajectories would be left")

        if self.features == "none":
            if self.augmenter is not None:
                coords = self.augmenter(coords)
            # Empty features
            features = np.zeros((len(coords), 0))

        elif self.features in ("regionprops", "regionprops2"):
            if self.augmenter is not None:
                (img2, mask2, coords2), idx = self.augmenter(
                    img, mask, coords, timepoints - min_time
                )
                if len(idx) > 0:
                    img, mask, coords = img2, mask2, coords2
                    labels = labels[idx]
                    timepoints = timepoints[idx]
                    assoc_matrix = assoc_matrix[idx][:, idx]
                    mask = mask.astype(int)
                else:
                    logger.debug(
                        "disable augmentation as no trajectories would be left"
                    )

            # One feature array per frame, concatenated along the token axis.
            features = tuple(
                extract_features_regionprops(
                    m, im, labels[timepoints == i + min_time], properties=self.features
                )
                for i, (m, im) in enumerate(zip(mask, img))
            )
            features = np.concatenate(features, axis=0)
            # features = np.zeros((len(coords), self.feat_dim))

        elif self.features == "patch":
            if self.augmenter is not None:
                (img2, mask2, coords2), idx = self.augmenter(
                    img, mask, coords, timepoints - min_time
                )
                if len(idx) > 0:
                    img, mask, coords = img2, mask2, coords2
                    labels = labels[idx]
                    timepoints = timepoints[idx]
                    assoc_matrix = assoc_matrix[idx][:, idx]
                    mask = mask.astype(int)
                else:
                    print("disable augmentation as no trajectories would be left")

            features = tuple(
                extract_features_patch(
                    m,
                    im,
                    coords[timepoints == min_time + i],
                    labels[timepoints == min_time + i],
                )
                for i, (m, im) in enumerate(zip(mask, img))
            )
            features = np.concatenate(features, axis=0)
        elif self.features == "patch_regionprops":
            if self.augmenter is not None:
                (img2, mask2, coords2), idx = self.augmenter(
                    img, mask, coords, timepoints - min_time
                )
                if len(idx) > 0:
                    img, mask, coords = img2, mask2, coords2
                    labels = labels[idx]
                    timepoints = timepoints[idx]
                    assoc_matrix = assoc_matrix[idx][:, idx]
                    mask = mask.astype(int)
                else:
                    print("disable augmentation as no trajectories would be left")

            # Patch features and regionprops features are concatenated per detection.
            features1 = tuple(
                extract_features_patch(
                    m,
                    im,
                    coords[timepoints == min_time + i],
                    labels[timepoints == min_time + i],
                )
                for i, (m, im) in enumerate(zip(mask, img))
            )
            features2 = tuple(
                extract_features_regionprops(
                    m,
                    im,
                    labels[timepoints == i + min_time],
                    properties=self.features,
                )
                for i, (m, im) in enumerate(zip(mask, img))
            )

            features = tuple(
                np.concatenate((f1, f2), axis=-1)
                for f1, f2 in zip(features1, features2)
            )

            features = np.concatenate(features, axis=0)

        # remove temporal offset and add timepoints to coords
        relative_timepoints = timepoints - track["t1"]
        coords = np.concatenate((relative_timepoints[:, None], coords), axis=-1)

        # Clip to at most max_tokens detections, cutting at a timepoint boundary.
        if self.max_tokens and len(timepoints) > self.max_tokens:
            time_incs = np.where(timepoints - np.roll(timepoints, 1))[0]
            n_elems = time_incs[np.searchsorted(time_incs, self.max_tokens) - 1]
            timepoints = timepoints[:n_elems]
            labels = labels[:n_elems]
            coords = coords[:n_elems]
            features = features[:n_elems]
            assoc_matrix = assoc_matrix[:n_elems, :n_elems]
            logger.info(
                f"Clipped window of size {timepoints[n_elems - 1] - timepoints.min()}"
            )

        coords0 = torch.from_numpy(coords).float()
        features = torch.from_numpy(features).float()
        assoc_matrix = torch.from_numpy(assoc_matrix.copy()).float()
        labels = torch.from_numpy(labels).long()
        timepoints = torch.from_numpy(timepoints).long()

        # With augmentation, jitter the spatial coords by a random global offset
        # (the time column at index 0 is left untouched).
        if self.augmenter is not None:
            coords = coords0.clone()
            coords[:, 1:] += torch.randint(0, 256, (1, self.ndim))
        else:
            coords = coords0.clone()
        res = dict(
            features=features,
            coords0=coords0,
            coords=coords,
            assoc_matrix=assoc_matrix,
            timepoints=timepoints,
            labels=labels,
        )

        if return_dense:
            if all([x is not None for x in img]):
                img = torch.from_numpy(img).float()
                res["img"] = img

            mask = torch.from_numpy(mask.astype(int)).long()
            res["mask"] = mask

        return res
    def _setup_features_augs_wrfeat(
        self, ndim: int, features: str, augment: int, crop_size: Tuple[int]
    ):
        """Build the feature dimension, augmentation pipeline and cropper for wrfeat mode.

        Args:
            ndim: Spatial dimensionality (2 or 3).
            features: Feature mode string (unused here; kept for signature parity).
            augment: Augmentation level 0-3; higher levels add more transforms.
            crop_size: Spatial crop size, or None to disable cropping.

        Returns:
            Tuple of (feat_dim, augmenter-or-None, cropper-or-None).
        """
        # FIXME: hardcoded
        feat_dim = 7 if ndim == 2 else 12
        if augment == 1:
            # Level 1: geometric transforms only.
            augmenter = wrfeat.WRAugmentationPipeline([
                wrfeat.WRRandomFlip(p=0.5),
                wrfeat.WRRandomAffine(
                    p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
                ),
                # wrfeat.WRRandomBrightness(p=0.8, factor=(0.5, 2.0)),
                # wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
            ])
        elif augment == 2:
            # Level 2: adds intensity and coordinate-offset augmentation.
            augmenter = wrfeat.WRAugmentationPipeline([
                wrfeat.WRRandomFlip(p=0.5),
                wrfeat.WRRandomAffine(
                    p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
                ),
                wrfeat.WRRandomBrightness(p=0.8),
                wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
            ])
        elif augment == 3:
            # Level 3: additionally perturbs object movement.
            augmenter = wrfeat.WRAugmentationPipeline([
                wrfeat.WRRandomFlip(p=0.5),
                wrfeat.WRRandomAffine(
                    p=0.8, degrees=180, scale=(0.5, 2), shear=(0.1, 0.1)
                ),
                wrfeat.WRRandomBrightness(p=0.8),
                wrfeat.WRRandomMovement(offset=(-10, 10), p=0.3),
                wrfeat.WRRandomOffset(p=0.8, offset=(-3, 3)),
            ])
        else:
            augmenter = None

        cropper = (
            wrfeat.WRRandomCrop(
                crop_size=crop_size,
                ndim=ndim,
            )
            if crop_size is not None
            else None
        )
        return feat_dim, augmenter, cropper
removing non masked parts for regionprops features) + self.imgs = np.stack([ + _compress_img_mask_preproc(im, mask, self.features) + for im, mask in zip(self.imgs, self.gt_masks) + ]) + + assert len(self.gt_masks) == len(self.imgs) + + # Load each of the detection folders and create data samples with a sliding window + windows = [] + self.properties_by_time = dict() + self.det_masks = dict() + logger.info("Loading detections") + for _f in self.detection_folders: + det_folder = self.root / _f + + if det_folder == self.gt_mask_folder: + det_masks = self.gt_masks + logger.info("DET MASK:\tUsing GT masks") + # identity matching + det_gt_matching = { + t: {_l: _l for _l in set(np.unique(d)) - {0}} + for t, d in enumerate(det_masks) + } + else: + det_folder = self._guess_det_folder(root=self.root, suffix=_f) + if det_folder is None: + continue + logger.info(f"DET MASK (guessed):\t{det_folder}") + det_masks = self._load_tiffs(det_folder, dtype=np.int32) + det_masks = self._correct_gt_with_st( + det_folder, det_masks, dtype=np.int32 + ) + det_masks = self._check_dimensions(det_masks) + # FIXME matching can be slow for big images + # raise NotImplementedError("Matching not implemented for 3d version") + det_gt_matching = { + t: { + _d: _gt + for _gt, _d in matching( + self.gt_masks[t], + det_masks[t], + threshold=0.3, + max_distance=16, + ) + } + for t in tqdm(range(len(det_masks)), leave=False, desc="Matching") + } + + self.det_masks[_f] = det_masks + + # build features + + features = joblib.Parallel(n_jobs=8)( + joblib.delayed(wrfeat.WRFeatures.from_mask_img)( + mask=mask[None], img=img[None], t_start=t + ) + for t, (mask, img) in enumerate(zip(det_masks, self.imgs)) + ) + + properties_by_time = dict() + for _t, _feats in enumerate(features): + properties_by_time[_t] = dict( + coords=_feats.coords, labels=_feats.labels + ) + self.properties_by_time[_f] = properties_by_time + + _w = self._build_windows_wrfeat( + features, + det_masks, + det_gt_matching, + ) + + 
windows.extend(_w) + + return windows + + def _build_windows_wrfeat( + self, + features: Sequence[wrfeat.WRFeatures], + det_masks: np.ndarray, + matching: Tuple[dict], + ): + assert len(self.imgs) == len(det_masks) + + window_size = self.window_size + windows = [] + + # Creates the data samples with a sliding window + for t1, t2 in tqdm( + zip(range(0, len(det_masks)), range(window_size, len(det_masks) + 1)), + total=len(det_masks) - window_size + 1, + leave=False, + desc="Building windows", + ): + img = self.imgs[t1:t2] + mask = det_masks[t1:t2] + feat = wrfeat.WRFeatures.concat(features[t1:t2]) + + labels = feat.labels + timepoints = feat.timepoints + coords = feat.coords + + if len(feat) == 0: + A = np.zeros((0, 0), dtype=bool) + coords = np.zeros((0, feat.ndim), dtype=int) + else: + # build matrix from incomplete labels, but full lineage graph. If a label is missing, I should skip over it. + A = _ctc_assoc_matrix( + labels, + timepoints, + self.gt_graph, + matching, + ) + w = dict( + coords=coords, + # TODO imgs and masks are unaltered here + t1=t1, + img=img, + mask=mask, + assoc_matrix=A, + labels=labels, + timepoints=timepoints, + wrfeat=feat, + ) + windows.append(w) + + logger.debug(f"Built {len(windows)} track windows.\n") + return windows + + def _getitem_wrfeat(self, n: int, return_dense=None): + # if not set, use default + + if return_dense is None: + return_dense = self.return_dense + + track = self.windows[n] + # coords = track["coords"] + assoc_matrix = track["assoc_matrix"] + labels = track["labels"] + img = track["img"] + mask = track["mask"] + timepoints = track["timepoints"] + # track["t1"] + feat = track["wrfeat"] + + if return_dense and isinstance(mask, _CompressedArray): + mask = mask.decompress() + if return_dense and isinstance(img, _CompressedArray): + img = img.decompress() + if isinstance(assoc_matrix, _CompressedArray): + assoc_matrix = assoc_matrix.decompress() + + # cropping + if self.cropper is not None: + # Use only if there is at 
def _ctc_lineages(df, masks, t1=0, t2=None):
    """From a ctc dataframe, create a digraph that contains all sublineages
    between t1 and t2 (exclusive t2).

    Args:
        df: pd.DataFrame with columns `label`, `t1`, `t2`, `parent` (man_track.txt)
        masks: List of masks. If t1 is not 0, then the masks are assumed to be already cropped accordingly.
        t1: Start timepoint
        t2: End timepoint (exclusive). If None, then t2 is set to len(masks)

    Returns:
        labels: List of label ids extracted from the masks, ordered by timepoint.
        ts: List of corresponding timepoints
        graph: The digraph of the lineages between t1 and t2.
    """
    if t1 > 0:
        assert t2 is not None
        assert t2 - t1 == len(masks)
    if t2 is None:
        t2 = len(masks)

    graph = nx.DiGraph()
    labels = []
    ts = []

    # get all objects that are present in the time interval.
    # Copy so the time-offset correction below does not mutate the caller's
    # dataframe (and does not trigger pandas' chained-assignment warning).
    df_sub = df[(df.t1 < t2) & (df.t2 >= t1)].copy()

    # Correct offset
    df_sub.loc[:, "t1"] -= t1
    df_sub.loc[:, "t2"] -= t1

    # all_labels = df_sub.label.unique()
    # TODO speed up by precalculating unique values once
    # in_masks = set(np.where(np.bincount(np.stack(masks[t1:t2]).ravel()))[0]) - {0}
    # all_labels = [l for l in all_labels if l in in_masks]
    all_labels = set()

    for t in tqdm(
        range(0, t2 - t1), desc="Building and checking lineage graph", leave=False
    ):
        # get all entities at timepoint
        obs = df_sub[(df_sub.t1 <= t) & (df_sub.t2 >= t)]
        in_t = set(np.where(np.bincount(masks[t].ravel()))[0]) - {0}
        all_labels.update(in_t)
        for row in obs.itertuples():
            # NOTE: only label and parent are used per row; do NOT unpack
            # row.t1/row.t2 into locals here — that would shadow the t1/t2
            # function parameters.
            label, parent = row.label, row.parent
            if label not in in_t:
                continue

            labels.append(label)
            ts.append(t)

            # add label as node if not already in graph
            if not graph.has_node(label):
                graph.add_node(label)

            # Parents have been added in previous timepoints
            if parent in all_labels:
                if not graph.has_node(parent):
                    graph.add_node(parent)
                graph.add_edge(parent, label)

    labels = np.array(labels)
    ts = np.array(ts)
    return labels, ts, graph
def _ctc_assoc_matrix(detections, ts, graph, matching):
    """Create the association matrix for a list of labels and a tracklet parent -> children graph.

    Each detection is associated with all its ancestors and descendants, but not its siblings and their offspring.

    Args:
        detections: list of integer labels, ordered by timepoint
        ts: list of timepoints corresponding to the detections
        graph: networkx DiGraph with each ground truth tracklet id (spanning n timepoints) as a single node
            and parent -> children relationships as edges.
        matching: for each timepoint, a dictionary that maps from detection id to gt tracklet id

    Returns:
        Boolean (n_detections, n_detections) matrix A with A[i, j] True iff
        detection j's GT tracklet is in detection i's family (self/ancestor/descendant).
    """
    assert 0 not in graph
    # GT tracklet id for each detection; 0 means unmatched.
    matched_gt = []
    for i, (label, t) in enumerate(zip(detections, ts)):
        gt_tracklet_id = matching[t].get(label, 0)
        matched_gt.append(gt_tracklet_id)
    matched_gt = np.array(matched_gt, dtype=int)
    # Now we have the subset of gt nodes that is matched to any detection in the current window

    # relabel to reduce the size of lookup matrices
    # offset 0 not allowed in skimage, which makes this very annoying
    relabeled_gt, fwd_map, _inv_map = relabel_sequential(matched_gt, offset=1)
    # dict is faster than arraymap
    fwd_map = dict(zip(fwd_map.in_values, fwd_map.out_values))
    # inv_map = dict(zip(inv_map.in_values, inv_map.out_values))

    # the family relationships for each ground truth detection,
    # Maps from local detection number (0-indexed) to global gt tracklet id (1-indexed)
    family = np.zeros((len(detections), len(relabeled_gt) + 1), bool)

    # Connects each tracklet id with its children and parent tracklets (according to man_track.txt)
    for i, (label, t) in enumerate(zip(detections, ts)):
        # Get the original label corresponding to the graph
        gt_tracklet_id = matching[t].get(label, None)
        if gt_tracklet_id is not None:
            ancestors = []
            descendants = []
            # This iterates recursively through the graph; only relatives that
            # are themselves matched in this window (i.e. in fwd_map) are kept.
            for n in nx.descendants(graph, gt_tracklet_id):
                if n in fwd_map:
                    descendants.append(fwd_map[n])
            for n in nx.ancestors(graph, gt_tracklet_id):
                if n in fwd_map:
                    ancestors.append(fwd_map[n])

            family[i, np.array([fwd_map[gt_tracklet_id], *ancestors, *descendants])] = (
                True
            )
        else:
            pass
            # Now we match to nothing, so even the matrix diagonal will not be filled.

    # This assures that matching to 0 is always false
    assert family[:, 0].sum() == 0

    # Create the detection-to-detection association matrix
    A = np.zeros((len(detections), len(detections)), dtype=bool)

    _assoc(A, relabeled_gt, family)

    return A
def pad_tensor(x, n_max: int, dim=0, value=0):
    """Pad tensor ``x`` along ``dim`` up to length ``n_max`` with ``value``.

    Args:
        x: Input tensor.
        n_max: Target length along ``dim``; must be >= the current length.
        dim: Dimension to pad (default 0).
        value: Fill value for the padding (default 0).

    Returns:
        A new tensor with the same dtype and device whose size along
        ``dim`` is ``n_max``.

    Raises:
        ValueError: If ``n_max`` is smaller than the current size along ``dim``.
    """
    n = x.shape[dim]
    if n_max < n:
        raise ValueError(f"pad_tensor: n_max={n_max} must be larger than n={n} !")
    pad_shape = list(x.shape)
    pad_shape[dim] = n_max - n
    # Allocate the padding on the same device as x; otherwise torch.cat
    # fails for CUDA inputs (the previous device handling was commented out).
    pad = torch.full(pad_shape, fill_value=value, dtype=x.dtype, device=x.device)
    return torch.cat((x, pad), dim=dim)
+ } + n_pads = tuple(n_max_len - s for s in lens) + batch_new = dict( + ( + k, + torch.stack( + [pad_tensor(x[k], n_max=n_max_len, value=v) for x in batch], dim=0 + ), + ) + for k, v in normal_keys.items() + ) + batch_new["assoc_matrix"] = torch.stack( + [ + pad_tensor( + pad_tensor(x["assoc_matrix"], n_max_len, dim=0), n_max_len, dim=1 + ) + for x in batch + ], + dim=0, + ) + + # add boolean mask that signifies whether tokens are padded or not (such that they can be ignored later) + pad_mask = torch.zeros((len(batch), n_max_len), dtype=torch.bool) + for i, n_pad in enumerate(n_pads): + pad_mask[i, n_max_len - n_pad :] = True + + batch_new["padding_mask"] = pad_mask.bool() + return batch_new + + +if __name__ == "__main__": + dummy_data = CTCData( + root="../../scripts/data/synthetic_cells/01", + ndim=2, + detection_folders=["TRA"], + window_size=4, + max_tokens=None, + augment=3, + features="none", + downscale_temporal=1, + downscale_spatial=1, + sanity_dist=False, + crop_size=(256, 256), + ) + + x = dummy_data[0] diff --git a/models/tra_post_model/trackastra/data/distributed.py b/models/tra_post_model/trackastra/data/distributed.py new file mode 100644 index 0000000000000000000000000000000000000000..2134fcf71e2cfd998e3385d6ee9428278d8aaff5 --- /dev/null +++ b/models/tra_post_model/trackastra/data/distributed.py @@ -0,0 +1,316 @@ +"""Data loading and sampling utils for distributed training.""" + +import hashlib +import json +import logging +import pickle +# from collections.abc import Iterable +from copy import deepcopy +from pathlib import Path +from timeit import default_timer + +import numpy as np +import torch +# from lightning import LightningDataModule +from torch.utils.data import ( + BatchSampler, + ConcatDataset, + DataLoader, + Dataset, + DistributedSampler, +) +from typing import Optional, Iterable +from .data import CTCData + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def cache_class(cachedir=None): + """A simple file 
cache for CTCData.""" + + def make_hashable(obj): + if isinstance(obj, tuple | list): + return tuple(make_hashable(e) for e in obj) + elif isinstance(obj, Path): + return obj.as_posix() + elif isinstance(obj, dict): + return tuple(sorted((k, make_hashable(v)) for k, v in obj.items())) + else: + return obj + + def hash_args_kwargs(*args, **kwargs): + hashable_args = tuple(make_hashable(arg) for arg in args) + hashable_kwargs = make_hashable(kwargs) + combined_serialized = json.dumps( + [hashable_args, hashable_kwargs], sort_keys=True + ) + hash_obj = hashlib.sha256(combined_serialized.encode()) + return hash_obj.hexdigest() + + if cachedir is None: + return CTCData + else: + cachedir = Path(cachedir) + + def _wrapped(*args, **kwargs): + h = hash_args_kwargs(*args, **kwargs) + cachedir.mkdir(exist_ok=True, parents=True) + cache_file = cachedir / f"{h}.pkl" + if cache_file.exists(): + logger.info(f"Loading cached dataset from {cache_file}") + with open(cache_file, "rb") as f: + return pickle.load(f) + else: + c = CTCData(*args, **kwargs) + logger.info(f"Saving cached dataset to {cache_file}") + pickle.dump(c, open(cache_file, "wb")) + return c + + return _wrapped + + +class BalancedBatchSampler(BatchSampler): + """samples batch indices such that the number of objects in each batch is balanced + (so to reduce the number of paddings in the batch). + + + """ + + def __init__( + self, + dataset: torch.utils.data.Dataset, + batch_size: int, + n_pool: int = 10, + num_samples: Optional[int] = None, + weight_by_ndivs: bool = False, + weight_by_dataset: bool = False, + drop_last: bool = False, + ): + """Setting n_pool =1 will result in a regular random batch sampler. 
+ + weight_by_ndivs: if True, the probability of sampling an element is proportional to the number of divisions + weight_by_dataset: if True, the probability of sampling an element is inversely proportional to the length of the dataset + """ + if isinstance(dataset, CTCData): + self.n_objects = dataset.n_objects + self.n_divs = np.array(dataset.n_divs) + self.n_sizes = np.ones(len(dataset)) * len(dataset) + elif isinstance(dataset, ConcatDataset): + self.n_objects = tuple(n for d in dataset.datasets for n in d.n_objects) + self.n_divs = np.array(tuple(n for d in dataset.datasets for n in d.n_divs)) + self.n_sizes = np.array( + tuple(len(d) for d in dataset.datasets for _ in range(len(d))) + ) + else: + raise NotImplementedError( + f"BalancedBatchSampler: Unknown dataset type {type(dataset)}" + ) + assert len(self.n_objects) == len(self.n_divs) == len(self.n_sizes) + + self.batch_size = batch_size + self.n_pool = n_pool + self.drop_last = drop_last + self.num_samples = num_samples + self.weight_by_ndivs = weight_by_ndivs + self.weight_by_dataset = weight_by_dataset + logger.debug(f"{weight_by_ndivs=}") + logger.debug(f"{weight_by_dataset=}") + + def get_probs(self, idx): + idx = np.array(idx) + if self.weight_by_ndivs: + probs = 1 + np.sqrt(self.n_divs[idx]) + else: + probs = np.ones(len(idx)) + if self.weight_by_dataset: + probs = probs / (self.n_sizes[idx] + 1e-6) + + probs = probs / (probs.sum() + 1e-10) + return probs + + def sample_batches(self, idx: Iterable[int]): + # we will split the indices into pools of size n_pool + num_samples = self.num_samples if self.num_samples is not None else len(idx) + # sample from the indices with replacement and given probabilites + idx = np.random.choice(idx, num_samples, replace=True, p=self.get_probs(idx)) + + n_pool = min( + self.n_pool * self.batch_size, + (len(idx) // self.batch_size) * self.batch_size, + ) + + batches = [] + for i in range(0, len(idx), n_pool): + # the indices in the pool are sorted by their number of 
objects + idx_pool = idx[i : i + n_pool] + idx_pool = sorted(idx_pool, key=lambda i: self.n_objects[i]) + + # such that we can create batches where each element has a similar number of objects + jj = np.arange(0, len(idx_pool), self.batch_size) + np.random.shuffle(jj) + + for j in jj: + # dont drop_last, as this leads to a lot of lightning problems.... + # if j + self.batch_size > len(idx_pool): # assume drop_last=True + # continue + batch = idx_pool[j : j + self.batch_size] + batches.append(batch) + return batches + + def __iter__(self): + idx = np.arange(len(self.n_objects)) + batches = self.sample_batches(idx) + return iter(batches) + + def __len__(self): + if self.num_samples is not None: + return self.num_samples // self.batch_size + else: + return len(self.n_objects) // self.batch_size + + +class BalancedDistributedSampler(DistributedSampler): + def __init__( + self, + dataset: Dataset, + batch_size: int, + n_pool: int, + num_samples: int, + weight_by_ndivs: bool = False, + weight_by_dataset: bool = False, + *args, + **kwargs, + ) -> None: + super().__init__(dataset=dataset, *args, drop_last=True, **kwargs) + self._balanced_batch_sampler = BalancedBatchSampler( + dataset, + batch_size=batch_size, + n_pool=n_pool, + num_samples=max(1, num_samples // self.num_replicas), + weight_by_ndivs=weight_by_ndivs, + weight_by_dataset=weight_by_dataset, + ) + + def __len__(self) -> int: + if self.num_samples is not None: + return self._balanced_batch_sampler.num_samples + else: + return super().__len__() + + def __iter__(self): + indices = list(super().__iter__()) + batches = self._balanced_batch_sampler.sample_batches(indices) + for batch in batches: + yield from batch + + +# class BalancedDataModule(LightningDataModule): +# def __init__( +# self, +# input_train: list, +# input_val: list, +# cachedir: str, +# augment: int, +# distributed: bool, +# dataset_kwargs: dict, +# sampler_kwargs: dict, +# loader_kwargs: dict, +# ): +# super().__init__() +# self.input_train = 
input_train +# self.input_val = input_val +# self.cachedir = cachedir +# self.augment = augment +# self.distributed = distributed +# self.dataset_kwargs = dataset_kwargs +# self.sampler_kwargs = sampler_kwargs +# self.loader_kwargs = loader_kwargs + +# def prepare_data(self): +# """Loads and caches the datasets if not already done. + +# Running on the main CPU process. +# """ +# CTCData = cache_class(self.cachedir) +# datasets = dict() +# for split, inps in zip( +# ("train", "val"), +# (self.input_train, self.input_val), +# ): +# logger.info(f"Loading {split.upper()} data") +# start = default_timer() +# datasets[split] = torch.utils.data.ConcatDataset( +# CTCData( +# root=Path(inp), +# augment=self.augment if split == "train" else 0, +# **self.dataset_kwargs, +# ) +# for inp in inps +# ) +# logger.info( +# f"Loaded {len(datasets[split])} {split.upper()} samples (in" +# f" {(default_timer() - start):.1f} s)\n\n" +# ) + +# del datasets + +# def setup(self, stage: str): +# CTCData = cache_class(self.cachedir) +# self.datasets = dict() +# for split, inps in zip( +# ("train", "val"), +# (self.input_train, self.input_val), +# ): +# logger.info(f"Loading {split.upper()} data") +# start = default_timer() +# self.datasets[split] = torch.utils.data.ConcatDataset( +# CTCData( +# root=Path(inp), +# augment=self.augment if split == "train" else 0, +# **self.dataset_kwargs, +# ) +# for inp in inps +# ) +# logger.info( +# f"Loaded {len(self.datasets[split])} {split.upper()} samples (in" +# f" {(default_timer() - start):.1f} s)\n\n" +# ) + +# def train_dataloader(self): +# loader_kwargs = self.loader_kwargs.copy() +# if self.distributed: +# sampler = BalancedDistributedSampler( +# self.datasets["train"], +# **self.sampler_kwargs, +# ) +# batch_sampler = None +# else: +# sampler = None +# batch_sampler = BalancedBatchSampler( +# self.datasets["train"], +# **self.sampler_kwargs, +# ) +# if not loader_kwargs["batch_size"] == batch_sampler.batch_size: +# raise ValueError( +# f"Batch 
size in loader_kwargs ({loader_kwargs['batch_size']}) and sampler_kwargs ({batch_sampler.batch_size}) must match" +# ) +# del loader_kwargs["batch_size"] + +# loader = DataLoader( +# self.datasets["train"], +# sampler=sampler, +# batch_sampler=batch_sampler, +# **loader_kwargs, +# ) +# return loader + +# def val_dataloader(self): +# val_loader_kwargs = deepcopy(self.loader_kwargs) +# val_loader_kwargs["persistent_workers"] = False +# val_loader_kwargs["num_workers"] = 1 +# return DataLoader( +# self.datasets["val"], +# shuffle=False, +# **val_loader_kwargs, +# ) diff --git a/models/tra_post_model/trackastra/data/example_data.py b/models/tra_post_model/trackastra/data/example_data.py new file mode 100644 index 0000000000000000000000000000000000000000..2e9925eaa9a4fdc7ec7669b4fb17864d5b192f03 --- /dev/null +++ b/models/tra_post_model/trackastra/data/example_data.py @@ -0,0 +1,48 @@ +from pathlib import Path + +import tifffile + +root = Path(__file__).parent / "resources" + + +def example_data_bacteria(): + """Bacteria images and masks from. + + Van Vliet et al. Spatially Correlated Gene Expression in Bacterial Groups: The Role of Lineage History, Spatial Gradients, and Cell-Cell Interactions (2018) + https://doi.org/10.1016/j.cels.2018.03.009 + + subset of timelapse trpL/150310-11 + """ + img = tifffile.imread(root / "trpL_150310-11_img.tif") + mask = tifffile.imread(root / "trpL_150310-11_mask.tif") + return img, mask + + +def example_data_hela(): + """Hela data from the cell tracking challenge. + + Neumann et al. Phenotypic profiling of the human genome by time-lapse microscopy reveals cell division genes (2010) + + subset of Fluo-N2DL-HeLa/train/02 + """ + img = tifffile.imread(root / "Fluo_Hela_02_img.tif") + mask = tifffile.imread(root / "Fluo_Hela_02_ERR_SEG.tif") + print(img.shape, mask.shape) + return img, mask + + +def example_data_fluo_3d(): + """Fluo-N3DH-CHO data from the cell tracking challenge. + + Dzyubachyk et al. 
Advanced Level-Set-Based Cell Tracking in Time-Lapse Fluorescence Microscopy (2010) + + subset of Fluo-N3DH-CHO/train/02 + """ + img = tifffile.imread(root / "Fluo-N3DH-CHO_02_img.tif") + mask = tifffile.imread(root / "Fluo-N3DH-CHO_02_ERR_SEG.tif") + return img, mask + +def data_hela(): + img = tifffile.imread("02_imgs.tif") + mask = tifffile.imread("02_masks.tif") + return img, mask \ No newline at end of file diff --git a/models/tra_post_model/trackastra/data/features.py b/models/tra_post_model/trackastra/data/features.py new file mode 100644 index 0000000000000000000000000000000000000000..0c02f7817e57a204f05b0f0708e07c0c949de78a --- /dev/null +++ b/models/tra_post_model/trackastra/data/features.py @@ -0,0 +1,148 @@ +import itertools + +import numpy as np +import pandas as pd +from skimage.measure import regionprops_table + +# the property keys that are supported for 2 and 3 dim + +_PROPERTIES = { + 2: { + # FIXME: The only image regionprop possible now (when compressing) is mean_intensity, + # since we store a mask with the mean intensity of each detection as the image. 
+ "regionprops": ( + "label", + "area", + "intensity_mean", + "eccentricity", + "solidity", + "inertia_tensor", + ), + # faster + "regionprops2": ( + "label", + "area", + "intensity_mean", + "inertia_tensor", + ), + "patch_regionprops": ( + "label", + "area", + "intensity_mean", + "inertia_tensor", + ), + }, + 3: { + "regionprops2": ( + "label", + "area", + "intensity_mean", + "inertia_tensor", + ), + "patch_regionprops": ( + "label", + "area", + "intensity_mean", + "inertia_tensor", + ), + }, +} + + +def extract_features_regionprops( + mask: np.ndarray, + img: np.ndarray, + labels: np.ndarray, + properties="regionprops2", +): + ndim = mask.ndim + assert ndim in (2, 3) + assert mask.shape == img.shape + + prop_dict = _PROPERTIES[ndim] + if properties not in prop_dict: + raise ValueError(f"properties must be one of {prop_dict.keys()}") + properties_tuple = prop_dict[properties] + + assert properties_tuple[0] == "label" + + labels = np.asarray(labels) + + # remove mask labels that are not present + # not needed, remove for speed + # mask[~np.isin(mask, labels)] = 0 + + df = pd.DataFrame( + regionprops_table(mask, intensity_image=img, properties=properties_tuple) + ) + assert df.columns[0] == "label" + assert df.columns[1] == "area" + + # the bnumber of inertia tensor columns depends on the dimensionality + n_cols_inertia = ndim**2 + assert np.all(["inertia_tensor" in col for col in df.columns[-n_cols_inertia:]]) + + # Hack for backwards compatibility + if properties in ("regionprops", "patch_regionprops"): + # Nice for conceptual clarity, but does not matter for speed + # drop upper triangular part of symmetric inertia tensor + for i, j in itertools.product(range(ndim), repeat=2): + if i > j: + df.drop(f"inertia_tensor-{i}-{j}", axis=1, inplace=True) + + table = df.to_numpy() + table[:, 1] *= 0.001 + table[:, -n_cols_inertia:] *= 0.01 + # reorder according to labels + features = np.zeros((len(labels), len(df.columns) - 1)) + + # faster than iterating over pandas 
dataframe + for row in table: + # old version with tuple indexing, slow. + # n = labels.index(int(row.label)) + # features[n] = row.to_numpy()[1:] + + # Only process regions present in the labels + n = np.where(labels == int(row[0]))[0] + if len(n) > 0: + # Remove label column (0)! + features[n[0]] = row[1:] + + return features + + +def extract_features_patch( + mask: np.ndarray, + img: np.ndarray, + coords: np.ndarray, + labels: np.ndarray, + width_patch: int = 16, +): + """16x16 Image patch around detection.""" + ndim = mask.ndim + assert ndim in (2, 3) and mask.shape == img.shape + if len(coords) == 0: + return np.zeros((0, width_patch * width_patch)) + + pads = (width_patch // 2,) * ndim + + img = np.pad( + img, + tuple((p, p) for p in pads), + mode="constant", + ) + + coords = coords.astype(int) + np.array(pads) + + ss = tuple( + tuple(slice(_c - width_patch // 2, _c + width_patch // 2) for _c in c) + for c in coords + ) + fs = tuple(img[_s] for _s in ss) + + # max project along z if 3D + if ndim == 3: + fs = tuple(f.max(0) for f in fs) + + features = np.stack([f.flatten() for f in fs]) + return features diff --git a/models/tra_post_model/trackastra/data/matching.py b/models/tra_post_model/trackastra/data/matching.py new file mode 100644 index 0000000000000000000000000000000000000000..d684fc52d27a0597fb780f0c96befd17f2e6ba23 --- /dev/null +++ b/models/tra_post_model/trackastra/data/matching.py @@ -0,0 +1,251 @@ +# Adapted from https://github.com/stardist/stardist/blob/master/stardist/matching.py + +import numpy as np +from numba import jit +from scipy.optimize import linear_sum_assignment +from scipy.spatial.distance import cdist +from skimage.measure import regionprops + +matching_criteria = dict() + + +def label_are_sequential(y): + """Returns true if y has only sequential labels from 1...""" + labels = np.unique(y) + return (set(labels) - {0}) == set(range(1, 1 + labels.max())) + + +def is_array_of_integers(y): + return isinstance(y, np.ndarray) and 
np.issubdtype(y.dtype, np.integer) + + +def _check_label_array(y, name=None, check_sequential=False): + err = ValueError( + "{label} must be an array of {integers}.".format( + label="labels" if name is None else name, + integers=("sequential " if check_sequential else "") + + "non-negative integers", + ) + ) + + if not is_array_of_integers(y): + raise err + if len(y) == 0: + return True + if check_sequential and not label_are_sequential(y): + raise err + else: + if not y.min() >= 0: + raise err + return True + + +def label_overlap(x, y, check=True): + if check: + _check_label_array(x, "x", True) + _check_label_array(y, "y", True) + if not x.shape == y.shape: + raise ValueError("x and y must have the same shape") + return _label_overlap(x, y) + + +@jit(nopython=True) +def _label_overlap(x, y): + x = x.ravel() + y = y.ravel() + overlap = np.zeros((1 + x.max(), 1 + y.max()), dtype=np.uint32) + for i in range(len(x)): + overlap[x[i], y[i]] += 1 + return overlap[1:, 1:] + + +def _safe_divide(x, y, eps=1e-10): + """Computes a safe divide which returns 0 if y is zero.""" + if np.isscalar(x) and np.isscalar(y): + return x / y if np.abs(y) > eps else 0.0 + else: + out = np.zeros(np.broadcast(x, y).shape, np.float32) + np.divide(x, y, out=out, where=np.abs(y) > eps) + return out + + +def intersection_over_union(overlap): + _check_label_array(overlap, "overlap") + if np.sum(overlap) == 0: + return overlap + n_pixels_pred = np.sum(overlap, axis=0, keepdims=True) + n_pixels_true = np.sum(overlap, axis=1, keepdims=True) + return _safe_divide(overlap, (n_pixels_pred + n_pixels_true - overlap)) + + +def dist_score(y_true, y_pred, max_distance: int = 10): + """Compute distance score between centroids of regions in y_true and y_pred + and returns a score matrix of shape (n_true, n_pred) with values in [0,1] + where + distance >= max_distance -> score = 0 + distance = 0 -> score = 1. 
+ """ + c_true = np.stack([r.centroid for r in regionprops(y_true)], axis=0) + c_pred = np.stack([r.centroid for r in regionprops(y_pred)], axis=0) + dist = np.minimum(cdist(c_true, c_pred), max_distance) + score = 1 - dist / max_distance + return score + + +# copied from scikit-image master for now (remove when part of a release) +def relabel_sequential(label_field, offset=1): + """Relabel arbitrary labels to {`offset`, ... `offset` + number_of_labels}. + + This function also returns the forward map (mapping the original labels to + the reduced labels) and the inverse map (mapping the reduced labels back + to the original ones). + + Parameters + ---------- + label_field : numpy array of int, arbitrary shape + An array of labels, which must be non-negative integers. + offset : int, optional + The return labels will start at `offset`, which should be + strictly positive. + + Returns: + ------- + relabeled : numpy array of int, same shape as `label_field` + The input label field with labels mapped to + {offset, ..., number_of_labels + offset - 1}. + The data type will be the same as `label_field`, except when + offset + number_of_labels causes overflow of the current data type. + forward_map : numpy array of int, shape ``(label_field.max() + 1,)`` + The map from the original label space to the returned label + space. Can be used to re-apply the same mapping. See examples + for usage. The data type will be the same as `relabeled`. + inverse_map : 1D numpy array of int, of length offset + number of labels + The map from the new label space to the original space. This + can be used to reconstruct the original label field from the + relabeled one. The data type will be the same as `relabeled`. + + Notes: + ----- + The label 0 is assumed to denote the background and is never remapped. + + The forward map can be extremely big for some inputs, since its + length is given by the maximum of the label field. 
However, in most + situations, ``label_field.max()`` is much smaller than + ``label_field.size``, and in these cases the forward map is + guaranteed to be smaller than either the input or output images. + + Examples: + -------- + >>> from skimage.segmentation import relabel_sequential + >>> label_field = np.array([1, 1, 5, 5, 8, 99, 42]) + >>> relab, fw, inv = relabel_sequential(label_field) + >>> relab + array([1, 1, 2, 2, 3, 5, 4]) + >>> fw + array([0, 1, 0, 0, 0, 2, 0, 0, 3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 5]) + >>> inv + array([ 0, 1, 5, 8, 42, 99]) + >>> (fw[label_field] == relab).all() + True + >>> (inv[relab] == label_field).all() + True + >>> relab, fw, inv = relabel_sequential(label_field, offset=5) + >>> relab + array([5, 5, 6, 6, 7, 9, 8]) + """ + offset = int(offset) + if offset <= 0: + raise ValueError("Offset must be strictly positive.") + if np.min(label_field) < 0: + raise ValueError("Cannot relabel array that contains negative values.") + max_label = int(label_field.max()) # Ensure max_label is an integer + if not np.issubdtype(label_field.dtype, np.integer): + new_type = np.min_scalar_type(max_label) + label_field = label_field.astype(new_type) + labels = np.unique(label_field) + labels0 = labels[labels != 0] + new_max_label = offset - 1 + len(labels0) + new_labels0 = np.arange(offset, new_max_label + 1) + output_type = label_field.dtype + required_type = np.min_scalar_type(new_max_label) + if np.dtype(required_type).itemsize > np.dtype(label_field.dtype).itemsize: + output_type = required_type + forward_map = np.zeros(max_label + 1, dtype=output_type) + forward_map[labels0] = new_labels0 + inverse_map = np.zeros(new_max_label + 1, dtype=output_type) + inverse_map[offset:] = labels0 + relabeled = 
forward_map[label_field] + return relabeled, forward_map, inverse_map + + +def matching(y_true, y_pred, threshold=0.5, max_distance: int = 16): + """Computes IoU and distance score between all pairs of regions in y_true and y_pred. + + returns the true/pred matching based on the higher of the two scores for each pair of regions + + Parameters + ---------- + y_true: ndarray + ground truth label image (integer valued) + y_pred: ndarray + predicted label image (integer valued) + threshold: float + threshold for matching criterion (default 0.5) + max_distance: int + maximum distance between centroids of regions in y_true and y_pred (default 16) + + Returns: + ------- + gt_pred: tuple + tuple of all matched region label pairs in y_true and y_pred + + + """ + y_true, y_pred = y_true.astype(np.int32), y_pred.astype(np.int32) + _check_label_array(y_true, "y_true") + _check_label_array(y_pred, "y_pred") + if not y_true.shape == y_pred.shape: + raise ValueError( + f"y_true ({y_true.shape}) and y_pred ({y_pred.shape}) have different shapes" + ) + if threshold is None: + threshold = 0 + + threshold = float(threshold) if np.isscalar(threshold) else map(float, threshold) + + y_true, _, map_rev_true = relabel_sequential(y_true) + y_pred, _, map_rev_pred = relabel_sequential(y_pred) + + overlap = label_overlap(y_true, y_pred, check=False) + + scores_iou = intersection_over_union(overlap) + scores_dist = dist_score(y_true, y_pred, max_distance) + scores = np.maximum(scores_iou, scores_dist) + + assert 0 <= np.min(scores) <= np.max(scores) <= 1 + + n_true, n_pred = scores.shape + n_matched = min(n_true, n_pred) + + # not_trivial = n_matched > 0 and np.any(scores >= thr) + not_trivial = n_matched > 0 + if not_trivial: + # compute optimal matching with scores as tie-breaker + costs = -(scores >= threshold).astype(float) - scores / (2 * n_matched) + true_ind, pred_ind = linear_sum_assignment(costs) + assert n_matched == len(true_ind) == len(pred_ind) + match_ok = scores[true_ind, 
pred_ind] >= threshold + true_ind = true_ind[match_ok] + pred_ind = pred_ind[match_ok] + matched = tuple( + (int(map_rev_true[i]), int(map_rev_pred[j])) + for i, j in zip(1 + true_ind, 1 + pred_ind) + ) + else: + matched = () + + return matched diff --git a/models/tra_post_model/trackastra/data/utils.py b/models/tra_post_model/trackastra/data/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..87d2f79231a50ba442a1b5886167c46691816570 --- /dev/null +++ b/models/tra_post_model/trackastra/data/utils.py @@ -0,0 +1,232 @@ +import logging +import sys +from pathlib import Path + +import numpy as np +import pandas as pd + +# from .data import CTCData +import tifffile +from tqdm import tqdm +from typing import Optional, Union, Tuple + +logger = logging.getLogger(__name__) + + +def load_tiff_timeseries( + dir: Path, + dtype: Optional[Union[str, type]] = None, + downscale: Optional[Tuple[int, ...]] = None, + start_frame: int = 0, + end_frame: Optional[int] = None, +) -> np.ndarray: + """Loads a folder of `.tif` or `.tiff` files into a numpy array. + Each file is interpreted as a frame of a time series. + + Args: + folder: + dtype: + downscale: One int for each dimension of the data. Avoids memory overhead. + start_frame: The first frame to load. + end_frame: The last frame to load. + + Returns: + np.ndarray: The loaded data. 
+ """ + # TODO make safe for label arrays + logger.debug(f"Loading tiffs from {dir} as {dtype}") + files = sorted(list(dir.glob("*.tif")) + list(dir.glob("*.tiff")))[ + start_frame:end_frame + ] + shape = tifffile.imread(files[0]).shape + if downscale: + assert len(downscale) == len(shape) + else: + downscale = (1,) * len(shape) + + files = files[:: downscale[0]] + + x = [] + for f in tqdm( + files, + leave=False, + desc=f"Loading [{start_frame}:{end_frame}:{downscale[0]}]", + ): + _x = tifffile.imread(f) + if dtype: + _x = _x.astype(dtype) + assert _x.shape == shape + slices = tuple(slice(None, None, d) for d in downscale[1:]) + _x = _x[slices] + x.append(_x) + + x = np.stack(x) + logger.debug(f"Loaded array of shape {x.shape} from {dir}") + return x + + +def load_tracklet_links(folder: Path) -> pd.DataFrame: + candidates = [ + folder / "man_track.txt", + folder / "res_track.txt", + ] + for c in candidates: + if c.exists(): + path = c + break + else: + raise FileNotFoundError(f"Could not find tracklet links in {folder}") + + df = pd.read_csv( + path, + delimiter=" ", + names=["label", "t1", "t2", "parent"], + dtype=int, + ) + # Remove invalid tracks with t2 > t1 + df = df[df.t1 <= df.t2] + + n_dets = (df.t2 - df.t1 + 1).sum() + logger.debug(f"{folder} has {n_dets} detections") + + n_divs = (df[df.parent != 0]["parent"].value_counts() == 2).sum() + logger.debug(f"{folder} has {n_divs} divisions") + return df + + +def filter_track_df( + df: pd.DataFrame, + start_frame: int = 0, + end_frame: int = sys.maxsize, + downscale: int = 1, +) -> pd.DataFrame: + """Only keep tracklets that are present in the given time interval.""" + df.columns = ["label", "t1", "t2", "parent"] + # only retain cells in interval + df = df[(df.t2 >= start_frame) & (df.t1 < end_frame)] + + # shift start and end of each cell + df.t1 = df.t1 - start_frame + df.t2 = df.t2 - start_frame + # set start/end to min/max + df.t1 = df.t1.clip(0, end_frame - start_frame - 1) + df.t2 = df.t2.clip(0, 
end_frame - start_frame - 1) + # set all parents to 0 that are not in the interval + df.loc[~df.parent.isin(df.label), "parent"] = 0 + + if downscale > 1: + if start_frame % downscale != 0: + raise ValueError("start_frame must be a multiple of downscale") + + logger.debug(f"Temporal downscaling of tracklet links by {downscale}") + + # remove tracklets that have been fully deleted by temporal downsampling + + mask = ( + # (df["t2"] - df["t1"] < downscale - 1) + (df["t1"] % downscale != 0) + & (df["t2"] % downscale != 0) + & (df["t1"] // downscale == df["t2"] // downscale) + ) + logger.debug( + f"Remove {mask.sum()} tracklets that are fully deleted by downsampling" + ) + logger.debug(f"Remove {df[mask]}") + + df = df[~mask] + # set parent to 0 if it has been deleted + df.loc[~df.parent.isin(df.label), "parent"] = 0 + + df["t2"] = (df["t2"] / float(downscale)).apply(np.floor).astype(int) + df["t1"] = (df["t1"] / float(downscale)).apply(np.ceil).astype(int) + + # Correct for edge case of single frame tracklet + assert np.all(df["t1"] == np.minimum(df["t1"], df["t2"])) + + return df + + +# TODO fix +# def dataset_to_ctc(dataset: CTCData, path, start: int = 0, stop: int | None = None): +# """save dataset to ctc format for debugging purposes""" +# out = Path(path) +# print(f"Saving dataset to {out}") +# out_img = out / "img" +# out_img.mkdir(exist_ok=True, parents=True) +# out_mask = out / "TRA" +# out_mask.mkdir(exist_ok=True, parents=True) +# if stop is None: +# stop = len(self) +# lines = [] +# masks, imgs = [], [] +# t_offset = 0 +# max_mask = 0 +# n_lines = 0 +# all_coords = [] +# for i in tqdm(range(start, stop)): +# d = dataset.__getitem__(i, return_dense=True) +# mask = d["mask"].numpy() +# mask[mask > 0] += max_mask +# max_mask = max(max_mask, mask.max()) +# masks.extend(mask) +# imgs.extend(d["img"].numpy()) +# # add vertices +# coords = d["coords0"].numpy() +# ts, coords = coords[:, 0].astype(int), coords[:, 1:] +# A = d["assoc_matrix"].numpy() +# t_unique = 
sorted(np.unique(ts)) +# for t1, t2 in zip(t_unique[:-1], t_unique[1:]): +# A_sub = A[ts == t1][:, ts == t2] +# for i, a in enumerate(A_sub): + +# v1 = coords[ts == t1][i] +# for j in np.where(a > 0)[0]: +# v2 = coords[ts == t2][j] +# # lines.append( +# # { +# # "index": n_lines, +# # "shape-type": "line", +# # "vertex-index": 0, +# # "axis-0": t2 + t_offset, +# # "axis-1": v1[0], +# # "axis-2": v1[1], +# # } +# # ) +# # lines.append( +# # { +# # "index": n_lines, +# # "shape-type": "line", +# # "vertex-index": 1, +# # "axis-0": t2 + t_offset, +# # "axis-1": v2[0], +# # "axis-2": v2[1], +# # } +# # ) +# lines.append([n_lines, "line", 0, t2 + t_offset] + v1.tolist()) +# lines.append([n_lines, "line", 1, t2 + t_offset] + v2.tolist()) +# n_lines += 1 + +# c = d["coords0"].numpy() +# c[:, 0] += t_offset +# all_coords.extend(c) +# t_offset += len(mask) + +# ax_cols = [f"axis-{i}" for i in range(dataset.ndim + 1)] +# df = pd.DataFrame(lines, columns=["index", "shape-type", "vertex-index"] + ax_cols) +# df.to_csv(out / "lines.csv", index=False) + +# df_c = pd.DataFrame(all_coords, columns=ax_cols) +# df_c.to_csv(out / "coords.csv", index=False) + +# for i, m in enumerate(imgs): +# # tifffile.imwrite(out_img/f'img_{i:04d}.tif', m) +# if dataset.ndim == 2: +# imageio.imwrite( +# out_img / f"img_{i:04d}.jpg", +# np.clip(20 + 100 * m, 0, 255).astype(np.uint8), +# ) + +# for i, m in enumerate(masks): +# tifffile.imwrite(out_mask / f"mask_{i:04d}.tif", m, compression="zstd") + +# return d diff --git a/models/tra_post_model/trackastra/data/wrfeat.py b/models/tra_post_model/trackastra/data/wrfeat.py new file mode 100644 index 0000000000000000000000000000000000000000..5d71238f7bf6341d27f091b2fe84b2591ef536cc --- /dev/null +++ b/models/tra_post_model/trackastra/data/wrfeat.py @@ -0,0 +1,655 @@ +"""Regionprops features and its augmentations. +WindowedRegionFeatures (WRFeatures) is a class that holds regionprops features for a windowed track region. 
+""" + +import itertools +import logging +from collections import OrderedDict +from collections.abc import Iterable #, Sequence +from functools import reduce +from typing import Literal + +import joblib +import numpy as np +import pandas as pd +from edt import edt +from skimage.measure import regionprops, regionprops_table +from tqdm import tqdm +from typing import Tuple, Optional, Sequence, Union, List +import typing + +try: + from .utils import load_tiff_timeseries +except: + from utils import load_tiff_timeseries +import torch +logger = logging.getLogger(__name__) + +_PROPERTIES = { + "regionprops": ( + "area", + "intensity_mean", + "intensity_max", + "intensity_min", + "inertia_tensor", + ), + "regionprops2": ( + "equivalent_diameter_area", + "intensity_mean", + "inertia_tensor", + "border_dist", + ), +} + + +def _filter_points( + points: np.ndarray, shape: Tuple[int], origin: Optional[Tuple[int]] = None +) -> np.ndarray: + """Returns indices of points that are inside the shape extent and given origin.""" + ndim = points.shape[-1] + if origin is None: + origin = (0,) * ndim + + idx = tuple( + np.logical_and(points[:, i] >= origin[i], points[:, i] < origin[i] + shape[i]) + for i in range(ndim) + ) + idx = np.where(np.all(idx, axis=0))[0] + return idx + + +def _border_dist(mask: np.ndarray, cutoff: float = 5): + """Returns distance to border normalized to 0 (at least cutoff away) and 1 (at border).""" + border = np.zeros_like(mask) + + # only apply to last two dimensions + ss = tuple( + slice(None) if i < mask.ndim - 2 else slice(1, -1) + for i, s in enumerate(mask.shape) + ) + border[ss] = 1 + dist = 1 - np.minimum(edt(border) / cutoff, 1) + return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist)) + + +def _border_dist_fast(mask: np.ndarray, cutoff: float = 5): + cutoff = int(cutoff) + border = np.ones(mask.shape, dtype=np.float32) + ndim = len(mask.shape) + + for axis, size in enumerate(mask.shape): + # Create fade values for the band [0, 
cutoff) + band_vals = np.arange(cutoff, dtype=np.float32) / cutoff + + # Build slices for the low border + low_slices = [slice(None)] * ndim + low_slices[axis] = slice(0, cutoff) + border_low = border[tuple(low_slices)] + border_low_vals = np.minimum( + border_low, band_vals[(...,) + (None,) * (ndim - axis - 1)] + ) + border[tuple(low_slices)] = border_low_vals + + # Build slices for the high border + high_slices = [slice(None)] * ndim + high_slices[axis] = slice(size - cutoff, size) + band_vals_rev = band_vals[::-1] + border_high = border[tuple(high_slices)] + border_high_vals = np.minimum( + border_high, band_vals_rev[(...,) + (None,) * (ndim - axis - 1)] + ) + border[tuple(high_slices)] = border_high_vals + + dist = 1 - border + return tuple(r.intensity_max for r in regionprops(mask, intensity_image=dist)) + + +class WRFeatures: + """regionprops features for a windowed track region.""" + + def __init__( + self, + coords: np.ndarray, + labels: np.ndarray, + timepoints: np.ndarray, + features: typing.OrderedDict[str, np.ndarray], + ): + self.ndim = coords.shape[-1] + if self.ndim not in (2, 3): + raise ValueError("Only 2D or 3D data is supported") + + self.coords = coords + self.labels = labels + self.features = features.copy() + self.timepoints = timepoints + + def __repr__(self): + s = ( + f"WindowRegionFeatures(ndim={self.ndim}, nregions={len(self.labels)}," + f" ntimepoints={len(np.unique(self.timepoints))})\n\n" + ) + for k, v in self.features.items(): + s += f"{k:>20} -> {v.shape}\n" + return s + + @property + def features_stacked(self): + return np.concatenate([v for k, v in self.features.items()], axis=-1) + + def __len__(self): + return len(self.labels) + + def __getitem__(self, key): + if key in self.features: + return self.features[key] + else: + raise KeyError(f"Key {key} not found in features") + + @classmethod + def concat(cls, feats: Sequence["WRFeatures"]) -> "WRFeatures": + """Concatenate multiple WRFeatures into a single one.""" + if len(feats) 
== 0: + raise ValueError("Cannot concatenate empty list of features") + return reduce(lambda x, y: x + y, feats) + + def __add__(self, other: "WRFeatures") -> "WRFeatures": + """Concatenate two WRFeatures.""" + if self.ndim != other.ndim: + raise ValueError("Cannot concatenate features of different dimensions") + if self.features.keys() != other.features.keys(): + raise ValueError("Cannot concatenate features with different properties") + + coords = np.concatenate([self.coords, other.coords], axis=0) + labels = np.concatenate([self.labels, other.labels], axis=0) + timepoints = np.concatenate([self.timepoints, other.timepoints], axis=0) + + features = OrderedDict( + (k, np.concatenate([v, other.features[k]], axis=0)) + for k, v in self.features.items() + ) + + return WRFeatures( + coords=coords, labels=labels, timepoints=timepoints, features=features + ) + + @classmethod + def from_mask_img( + cls, + mask: np.ndarray, + img: np.ndarray, + properties="regionprops2", + t_start: int = 0, + ): + img = np.asarray(img) + mask = np.asarray(mask) + + _ntime, ndim = mask.shape[0], mask.ndim - 1 + if ndim not in (2, 3): + raise ValueError("Only 2D or 3D data is supported") + + properties = tuple(_PROPERTIES[properties]) + if "label" in properties or "centroid" in properties: + raise ValueError( + f"label and centroid should not be in properties {properties}" + ) + + if "border_dist" in properties: + use_border_dist = True + # remove border_dist from properties + properties = tuple(p for p in properties if p != "border_dist") + else: + use_border_dist = False + + df_properties = ("label", "centroid", *properties) + dfs = [] + for i, (y, x) in enumerate(zip(mask, img)): + _df = pd.DataFrame( + regionprops_table(y, intensity_image=x, properties=df_properties) + ) + _df["timepoint"] = i + t_start + + if use_border_dist: + _df["border_dist"] = _border_dist_fast(y) + + dfs.append(_df) + df = pd.concat(dfs) + + if use_border_dist: + properties = (*properties, "border_dist") + + 
timepoints = df["timepoint"].values.astype(np.int32) + labels = df["label"].values.astype(np.int32) + coords = df[[f"centroid-{i}" for i in range(ndim)]].values.astype(np.float32) + + features = OrderedDict( + ( + p, + np.stack( + [ + df[c].values.astype(np.float32) + for c in df.columns + if c.startswith(p) + ], + axis=-1, + ), + ) + for p in properties + ) + + return cls( + coords=coords, labels=labels, timepoints=timepoints, features=features + ) + + +# augmentations + + +class WRRandomCrop: + """windowed region random crop augmentation.""" + + def __init__( + self, + crop_size: Optional[Union[int, Tuple[int]]] = None, + ndim: int = 2, + ) -> None: + """crop_size: tuple of int + can be tuple of length 1 (all dimensions) + of length ndim (y,x,...) + of length 2*ndim (y1,y2, x1,x2, ...). + """ + if isinstance(crop_size, int): + crop_size = (crop_size,) * 2 * ndim + elif isinstance(crop_size, Iterable): + pass + else: + raise ValueError(f"{crop_size} has to be int or tuple of int") + + if len(crop_size) == 1: + crop_size = (crop_size[0],) * 2 * ndim + elif len(crop_size) == ndim: + crop_size = tuple(itertools.chain(*tuple((c, c) for c in crop_size))) + elif len(crop_size) == 2 * ndim: + pass + else: + raise ValueError(f"crop_size has to be of length 1, {ndim}, or {2 * ndim}") + + crop_size = np.array(crop_size) + self._ndim = ndim + self._crop_bounds = crop_size[::2], crop_size[1::2] + self._rng = np.random.RandomState() + + def __call__(self, features: WRFeatures): + crop_size = self._rng.randint(self._crop_bounds[0], self._crop_bounds[1] + 1) + points = features.coords + + if len(points) == 0: + print("No points given, cannot ensure inside points") + return features + + # sample point and corner relative to it + + _idx = np.random.randint(len(points)) + corner = ( + points[_idx] + - crop_size + + 1 + + self._rng.randint(crop_size // 4, 3 * crop_size // 4) + ) + + idx = _filter_points(points, shape=crop_size, origin=corner) + + return ( + WRFeatures( + 
coords=points[idx], + labels=features.labels[idx], + timepoints=features.timepoints[idx], + features=OrderedDict((k, v[idx]) for k, v in features.features.items()), + ), + idx, + ) + + +class WRBaseAugmentation: + def __init__(self, p: float = 0.5) -> None: + self._p = p + self._rng = np.random.RandomState() + + def __call__(self, features: WRFeatures): + if self._rng.rand() > self._p or len(features) == 0: + return features + return self._augment(features) + + def _augment(self, features: WRFeatures): + raise NotImplementedError() + + +class WRRandomFlip(WRBaseAugmentation): + def _augment(self, features: WRFeatures): + ndim = features.ndim + flip = self._rng.randint(0, 2, features.ndim) + points = features.coords.copy() + for i, f in enumerate(flip): + if f == 1: + points[:, ndim - i - 1] *= -1 + return WRFeatures( + coords=points, + labels=features.labels, + timepoints=features.timepoints, + features=features.features, + ) + + +def _scale_matrix(sz: float, sy: float, sx: float): + return np.diag([sz, sy, sx]) + + +# def _scale_matrix(sy: float, sx: float): +# return np.array([[1, 0, 0], [0, sy, 0], [0, 0, sx]]) + + +def _shear_matrix(shy: float, shx: float): + return np.array([[1, 0, 0], [0, 1 + shx * shy, shy], [0, shx, 1]]) + + +def _rotation_matrix(theta: float): + return np.array([ + [1, 0, 0], + [0, np.cos(theta), -np.sin(theta)], + [0, np.sin(theta), np.cos(theta)], + ]) + + +def _transform_affine(k: str, v: np.ndarray, M: np.ndarray): + ndim = len(M) + if k == "area": + v = np.linalg.det(M) * v + elif k == "equivalent_diameter_area": + v = np.linalg.det(M) ** (1 / len(M)) * v + + elif k == "inertia_tensor": + # v' = M * v * M^T + v = v.reshape(-1, ndim, ndim) + # v * M^T + v = np.einsum("ijk, mk -> ijm", v, M) + # M * v + v = np.einsum("ij, kjm -> kim", M, v) + v = v.reshape(-1, ndim * ndim) + elif k in ( + "intensity_mean", + "intensity_std", + "intensity_max", + "intensity_min", + "border_dist", + ): + pass + else: + raise ValueError(f"Don't know how to 
class WRRandomAffine(WRBaseAugmentation):
    """Random affine transform (rotation, scale, shear) of coordinates and geometric features.

    Args:
        degrees: Max absolute rotation angle in degrees (None disables rotation).
        scale: (low, high) range for per-axis scale factors (None disables scaling).
        shear: Max absolute shear factors (shy, shx) (None disables shear).
        p: Probability of applying the augmentation.
    """

    def __init__(
        self,
        degrees: float = 10,
        # FIX: these were annotated `float` but actually hold 2-tuples.
        scale: Optional[Tuple[float, float]] = (0.9, 1.1),
        shear: Optional[Tuple[float, float]] = (0.1, 0.1),
        p: float = 0.5,
    ):
        super().__init__(p)
        self.degrees = degrees if degrees is not None else 0
        self.scale = scale if scale is not None else (1, 1)
        self.shear = shear if shear is not None else (0, 0)

    def _augment(self, features: "WRFeatures"):
        degrees = self._rng.uniform(-self.degrees, self.degrees) / 180 * np.pi
        scale = self._rng.uniform(*self.scale, 3)  # one factor per (z, y, x)
        shy = self._rng.uniform(-self.shear[0], self.shear[0])
        shx = self._rng.uniform(-self.shear[1], self.shear[1])

        self._M = (
            _rotation_matrix(degrees) @ _scale_matrix(*scale) @ _shear_matrix(shy, shx)
        )
        # M is built for 3D; crop the leading rows/cols for 2D data.
        self._M = self._M[-features.ndim :, -features.ndim :]
        points = features.coords @ self._M.T

        feats = OrderedDict(
            (k, _transform_affine(k, v, self._M)) for k, v in features.features.items()
        )
        return WRFeatures(
            coords=points,
            labels=features.labels,
            timepoints=features.timepoints,
            features=feats,
        )


class WRRandomBrightness(WRBaseAugmentation):
    """Random linear intensity rescaling: v -> v * scale + shift for intensity features."""

    def __init__(
        self,
        scale: Tuple[float, float] = (0.5, 2.0),
        shift: Tuple[float, float] = (-0.1, 0.1),
        p: float = 0.5,
    ):
        super().__init__(p)
        self.scale = scale
        self.shift = shift

    def _augment(self, features: "WRFeatures"):
        scale = self._rng.uniform(*self.scale)
        shift = self._rng.uniform(*self.shift)
        # Only intensity-derived features are affected; all others pass through.
        feats = OrderedDict(
            (k, v * scale + shift if "intensity" in k else v)
            for k, v in features.features.items()
        )
        return WRFeatures(
            coords=features.coords,
            labels=features.labels,
            timepoints=features.timepoints,
            features=feats,
        )


class WRRandomOffset(WRBaseAugmentation):
    """Add i.i.d. uniform noise to every coordinate."""

    def __init__(self, offset: Tuple[float, float] = (-3, 3), p: float = 0.5):
        super().__init__(p)
        self.offset = offset

    def _augment(self, features: "WRFeatures"):
        jitter = self._rng.uniform(*self.offset, features.coords.shape)
        return WRFeatures(
            coords=features.coords + jitter,
            labels=features.labels,
            timepoints=features.timepoints,
            features=features.features,
        )
class WRRandomMovement(WRBaseAugmentation):
    """Random global linear shift (drift) growing linearly with elapsed time."""

    # FIX: `offset` was annotated `float` but actually holds a (low, high) 2-tuple.
    def __init__(self, offset: Tuple[float, float] = (-10, 10), p: float = 0.5):
        super().__init__(p)
        self.offset = offset

    def _augment(self, features: "WRFeatures"):
        # One drift vector per spatial dimension, scaled by time since the first frame.
        base_offset = self._rng.uniform(*self.offset, features.coords.shape[-1])
        tmin = features.timepoints.min()
        offset = (features.timepoints[:, None] - tmin) * base_offset[None]
        return WRFeatures(
            coords=features.coords + offset,
            labels=features.labels,
            timepoints=features.timepoints,
            features=features.features,
        )


class WRAugmentationPipeline:
    """Compose a sequence of augmentations into a single callable."""

    def __init__(self, augmentations: Sequence[WRBaseAugmentation]):
        self.augmentations = augmentations

    def __call__(self, feats: "WRFeatures"):
        for aug in self.augmentations:
            feats = aug(feats)
        return feats
def get_features(
    detections: np.ndarray,
    imgs: Optional[np.ndarray] = None,
    features: Literal["none", "wrfeat"] = "wrfeat",
    ndim: int = 2,
    n_workers=0,
    progbar_class=tqdm,
) -> List[WRFeatures]:
    """Extract per-frame `WRFeatures` from detection masks and images.

    Args:
        detections: Label masks of shape (T, (Z), Y, X).
        imgs: Intensity images with the same shape as `detections`.
            NOTE(review): although annotated Optional, `None` is not handled
            below (`_check_dimensions`/`zip` would fail) — confirm callers
            always pass images.
        features: Feature flavor; currently unused by this function and kept
            for interface compatibility.
        ndim: Spatial dimensionality of the data (2 or 3).
        n_workers: Number of parallel processes (0 = single process).
        progbar_class: Progress bar factory (tqdm-compatible).

    Returns:
        List of `WRFeatures`, one per timepoint.
    """
    detections = _check_dimensions(detections, ndim)
    imgs = _check_dimensions(imgs, ndim)
    logger.info(f"Extracting features from {len(detections)} detections")
    frame_iter = progbar_class(
        enumerate(zip(detections, imgs)),
        total=len(imgs),
        desc="Extracting features",
    )
    # FIX: result was previously bound to the (unrelated) `features` parameter,
    # shadowing it; the single-process branch also returned a tuple despite the
    # declared List return type.
    if n_workers > 0:
        logger.info(f"Using {n_workers} processes for feature extraction")
        results = joblib.Parallel(n_jobs=n_workers, backend="loky")(
            joblib.delayed(WRFeatures.from_mask_img)(
                # New axis for time component; copies so loky can pickle safely.
                mask=mask[np.newaxis, ...].copy(),
                img=img[np.newaxis, ...].copy(),
                t_start=t,
            )
            for t, (mask, img) in frame_iter
        )
        results = list(results)
    else:
        logger.info("Using single process for feature extraction")
        results = [
            WRFeatures.from_mask_img(
                mask=mask[np.newaxis, ...],
                img=img[np.newaxis, ...],
                t_start=t,
            )
            for t, (mask, img) in frame_iter
        ]
    return results


def _check_dimensions(x: np.ndarray, ndim: int):
    """Validate the time-series array rank; expand 2D+time data to 3D+time when ndim==3.

    Raises:
        ValueError: if the array rank does not match the expected spatial dim.
    """
    if ndim == 2 and not x.ndim == 3:
        raise ValueError(f"Expected 2D data, got {x.ndim - 1}D data")
    elif ndim == 3:
        # If ndim=3 and data is two dimensional, it will be cast to 3D.
        if x.ndim == 3:
            x = np.expand_dims(x, axis=1)
        elif x.ndim == 4:
            pass
        else:
            raise ValueError(f"Expected 3D data, got {x.ndim - 1}D data")
    return x


def build_windows(
    features: List[WRFeatures], window_size: int, progbar_class=tqdm
) -> List[dict]:
    """Concatenate consecutive per-frame features into sliding temporal windows.

    Returns one dict per window carrying coords/labels/timepoints, the stacked
    feature matrix and the window start frame `t1`.
    """
    windows = []
    for t1, t2 in progbar_class(
        zip(range(0, len(features)), range(window_size, len(features) + 1)),
        total=len(features) - window_size + 1,
        desc="Building windows",
    ):
        feat = WRFeatures.concat(features[t1:t2])
        coords = feat.coords
        if len(feat) == 0:
            # Keep a well-formed (0, ndim) array for empty windows.
            coords = np.zeros((0, feat.ndim), dtype=int)
        windows.append(
            dict(
                coords=coords,
                t1=t1,
                labels=feat.labels,
                timepoints=feat.timepoints,
                features=feat.features_stacked,
            )
        )
    logger.debug(f"Built {len(windows)} track windows.\n")
    return windows
def build_windows_sd(
    features: List[WRFeatures], imgs_enc, imgs_stable, boxes, imgs, masks, window_size: int, progbar_class=tqdm
) -> List[dict]:
    """Like `build_windows`, but additionally attaches image / SD-encoder data
    (both numpy slices and torch tensor copies) to each window dict.
    """
    windows = []
    starts = range(0, len(features))
    ends = range(window_size, len(features) + 1)
    for t1, t2 in progbar_class(
        zip(starts, ends),
        total=len(features) - window_size + 1,
        desc="Building windows",
    ):
        feat = WRFeatures.concat(features[t1:t2])
        coords = feat.coords
        if len(feat) == 0:
            # Empty window: keep a well-formed (0, ndim) coordinate array.
            coords = np.zeros((0, feat.ndim), dtype=int)
        stacked = feat.features_stacked
        windows.append(
            dict(
                coords=coords,
                t1=t1,
                labels=feat.labels,
                timepoints=feat.timepoints,
                features=stacked,
                img_enc=imgs_enc[t1:t2],
                image_stable=imgs_stable[t1:t2],
                boxes=boxes,
                img=imgs[t1:t2],
                mask=masks[t1:t2],
                # Tensor duplicates of the numpy payload for direct model consumption.
                coords_t=torch.tensor(coords, dtype=torch.float32),
                labels_t=torch.tensor(feat.labels, dtype=torch.int32),
                timepoints_t=torch.tensor(feat.timepoints, dtype=torch.int64),
                features_t=torch.tensor(stacked, dtype=torch.float32),
                img_t=torch.tensor(imgs[t1:t2], dtype=torch.float32),
                mask_t=torch.tensor(masks[t1:t2], dtype=torch.int32),
            )
        )
    logger.debug(f"Built {len(windows)} track windows.\n")
    return windows


if __name__ == "__main__":
    # Smoke test on a local CTC dataset (hardcoded paths).
    imgs = load_tiff_timeseries(
        "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01",
    )
    masks = load_tiff_timeseries(
        "/scratch0/data/celltracking/ctc_2024/Fluo-N2DL-HeLa/train/01_GT/TRA",
        dtype=int,
    )

    features = get_features(detections=masks, imgs=imgs, ndim=3)
    windows = build_windows(features, window_size=4)
"""Transformer class."""

import logging
import os
import sys
from collections import OrderedDict
from pathlib import Path
from typing import Literal

import torch
import yaml
from torch import nn

# Make the vendored trackastra repo importable.
sys.path.append(os.path.join(os.getcwd(), "External_Repos", "trackastra"))

from ..utils import blockwise_causal_norm
from .model_parts import (
    FeedForward,
    PositionalEncoding,
    RelativePositionalAttention,
)

logger = logging.getLogger(__name__)


class EncoderLayer(nn.Module):
    """Pre-norm transformer encoder layer with relative positional self-attention."""

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        cutoff_spatial: int = 256,
        window: int = 16,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )
        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Self-attention + MLP, both with residual connections on the normed input."""
        x = self.norm1(x)

        # Passing coords=None disables the positional bias.
        a = self.attn(
            x,
            x,
            x,
            coords=coords if self.positional_bias else None,
            padding_mask=padding_mask,
        )

        # NOTE: the residual is taken from the *normed* x, as in the original.
        x = x + a
        x = x + self.mlp(self.norm2(x))
        return x
class DecoderLayer(nn.Module):
    """Pre-norm transformer decoder layer: cross-attention of x against y plus MLP."""

    def __init__(
        self,
        coord_dim: int = 2,
        d_model=256,
        num_heads=4,
        dropout=0.1,
        window: int = 16,
        cutoff_spatial: int = 256,
        positional_bias: Literal["bias", "rope", "none"] = "bias",
        positional_bias_n_spatial: int = 32,
        attn_dist_mode: str = "v0",
    ):
        super().__init__()
        self.positional_bias = positional_bias
        self.attn = RelativePositionalAttention(
            coord_dim,
            d_model,
            num_heads,
            cutoff_spatial=cutoff_spatial,
            n_spatial=positional_bias_n_spatial,
            cutoff_temporal=window,
            n_temporal=window,
            dropout=dropout,
            mode=positional_bias,
            attn_dist_mode=attn_dist_mode,
        )

        self.mlp = FeedForward(d_model)
        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.norm3 = nn.LayerNorm(d_model)

    def forward(
        self,
        x: torch.Tensor,
        y: torch.Tensor,
        coords: torch.Tensor,
        padding_mask: torch.Tensor = None,
    ):
        """Cross-attention (queries x, keys/values y) followed by an MLP block."""
        x = self.norm1(x)
        y = self.norm2(y)

        # Passing coords=None disables the positional bias.
        a = self.attn(
            x,
            y,
            y,
            coords=coords if self.positional_bias else None,
            padding_mask=padding_mask,
        )

        # Residual on the normed x, matching the encoder layer convention.
        x = x + a
        x = x + self.mlp(self.norm3(x))
        return x
torch.cdist(yx, yx) +# spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1) +# attn_mask.masked_fill_(spatial_mask, attn_ignore_val) + +# # dont add positional bias to self-attention if coords is None +# if coords is not None: +# if self._mode == "bias": +# attn_mask = attn_mask + self.pos_bias(coords) +# elif self._mode == "rope": +# q1, q2 = self.rot_pos_enc(q1, q2, coords) +# else: +# pass + +# dist = torch.cdist(coords, coords, p=2) +# attn_mask += torch.exp(-0.1 * dist.unsqueeze(1)) + +# # if given key_padding_mask = (B,N) then ignore those tokens (e.g. padding tokens) +# if padding_mask is not None: +# ignore_mask = torch.logical_or( +# padding_mask.unsqueeze(1), padding_mask.unsqueeze(2) +# ).unsqueeze(1) +# attn_mask.masked_fill_(ignore_mask, attn_ignore_val) + +# self.attn_mask = attn_mask.clone() + +# y1 = nn.functional.scaled_dot_product_attention( +# q1, +# q2, +# v1, +# attn_mask=attn_mask, +# dropout_p=self.dropout if self.training else 0, +# ) +# y2 = nn.functional.scaled_dot_product_attention( +# q2, +# q1, +# v2, +# attn_mask=attn_mask, +# dropout_p=self.dropout if self.training else 0, +# ) + +# y1 = y1.transpose(1, 2).contiguous().view(B, N, D) +# y1 = self.proj(y1) +# y2 = y2.transpose(1, 2).contiguous().view(B, N, D) +# y2 = self.proj(y2) +# return y1, y2 + + +# class BidirectionalCrossAttention(nn.Module): +# def __init__( +# self, +# coord_dim: int = 2, +# d_model=256, +# num_heads=4, +# dropout=0.1, +# window: int = 16, +# cutoff_spatial: int = 256, +# positional_bias: Literal["bias", "rope", "none"] = "bias", +# positional_bias_n_spatial: int = 32, +# ): +# super().__init__() +# self.positional_bias = positional_bias +# self.attn = BidirectionalRelativePositionalAttention( +# coord_dim, +# d_model, +# num_heads, +# cutoff_spatial=cutoff_spatial, +# n_spatial=positional_bias_n_spatial, +# cutoff_temporal=window, +# n_temporal=window, +# dropout=dropout, +# mode=positional_bias, +# ) + +# self.mlp = FeedForward(d_model) +# 
class TrackingTransformer(torch.nn.Module):
    """Encoder-decoder transformer predicting pairwise association logits
    between detections inside a temporal window.

    forward() embeds (time, space) coordinates and optional scalar features,
    runs an encoder (self-attention) and a decoder (cross-attention against
    the encoder output), and returns the outer product of the two resulting
    token embeddings as a (B, N, N) matrix of association logits.
    """

    def __init__(
        self,
        coord_dim: int = 3,
        feat_dim: int = 0,
        d_model: int = 128,
        nhead: int = 4,
        num_encoder_layers: int = 4,
        num_decoder_layers: int = 4,
        dropout: float = 0.1,
        pos_embed_per_dim: int = 32,
        feat_embed_per_dim: int = 1,
        window: int = 6,
        spatial_pos_cutoff: int = 256,
        attn_positional_bias: Literal["bias", "rope", "none"] = "rope",
        attn_positional_bias_n_spatial: int = 16,
        causal_norm: Literal[
            "none", "linear", "softmax", "quiet_softmax"
        ] = "quiet_softmax",
        attn_dist_mode: str = "v0",
    ):
        super().__init__()

        # Persisted verbatim so save()/from_folder() can reconstruct the model.
        self.config = dict(
            coord_dim=coord_dim,
            feat_dim=feat_dim,
            pos_embed_per_dim=pos_embed_per_dim,
            d_model=d_model,
            nhead=nhead,
            num_encoder_layers=num_encoder_layers,
            num_decoder_layers=num_decoder_layers,
            window=window,
            dropout=dropout,
            attn_positional_bias=attn_positional_bias,
            attn_positional_bias_n_spatial=attn_positional_bias_n_spatial,
            spatial_pos_cutoff=spatial_pos_cutoff,
            feat_embed_per_dim=feat_embed_per_dim,
            causal_norm=causal_norm,
            attn_dist_mode=attn_dist_mode,
        )

        # Input projection: concat(positional embedding, feature embedding) -> d_model.
        self.proj = nn.Linear(
            (1 + coord_dim) * pos_embed_per_dim + feat_dim * feat_embed_per_dim, d_model
        )
        self.norm = nn.LayerNorm(d_model)

        # Encoder and decoder layers share the same attention hyperparameters.
        layer_kwargs = dict(
            window=window,
            cutoff_spatial=spatial_pos_cutoff,
            positional_bias=attn_positional_bias,
            positional_bias_n_spatial=attn_positional_bias_n_spatial,
            attn_dist_mode=attn_dist_mode,
        )
        self.encoder = nn.ModuleList([
            EncoderLayer(coord_dim, d_model, nhead, dropout, **layer_kwargs)
            for _ in range(num_encoder_layers)
        ])
        self.decoder = nn.ModuleList([
            DecoderLayer(coord_dim, d_model, nhead, dropout, **layer_kwargs)
            for _ in range(num_decoder_layers)
        ])

        self.head_x = FeedForward(d_model)
        self.head_y = FeedForward(d_model)

        if feat_embed_per_dim > 1:
            self.feat_embed = PositionalEncoding(
                cutoffs=(1000,) * feat_dim,
                n_pos=(feat_embed_per_dim,) * feat_dim,
                cutoffs_start=(0.01,) * feat_dim,
            )
        else:
            self.feat_embed = nn.Identity()

        self.pos_embed = PositionalEncoding(
            cutoffs=(window,) + (spatial_pos_cutoff,) * coord_dim,
            n_pos=(pos_embed_per_dim,) * (1 + coord_dim),
        )

    def forward(self, coords, features=None, padding_mask=None, attn_feat=None):
        """Compute the (B, N, N) association logit matrix for one window.

        Args:
            coords: (B, N, 1 + coord_dim) tensor of (time, space) coordinates.
            features: optional (B, N, feat_dim) scalar features.
            padding_mask: optional (B, N) boolean mask of padded tokens.
            attn_feat: optional precomputed embedding added to the projected input.
        """
        assert coords.ndim == 3 and coords.shape[-1] in (3, 4)
        _B, _N, _D = coords.shape

        # Disable padded coords (such that they don't affect the minimum below).
        if padding_mask is not None:
            coords = coords.clone()
            coords[padding_mask] = coords.max()

        # Remove temporal offset so each window starts at t=0.
        min_time = coords[:, :, :1].min(dim=1, keepdims=True).values
        coords = coords - min_time

        pos = self.pos_embed(coords)

        if features is None or features.numel() == 0:
            features = pos
        else:
            features = self.feat_embed(features)
            features = torch.cat((pos, features), axis=-1)

        features = self.proj(features)
        if attn_feat is not None:
            # Add externally computed attention embedding.
            features = features + attn_feat

        features = self.norm(features)

        x = features
        for enc in self.encoder:
            x = enc(x, coords=coords, padding_mask=padding_mask)

        y = features
        # Decoder with cross attention against the encoder output.
        for dec in self.decoder:
            y = dec(y, x, coords=coords, padding_mask=padding_mask)

        x = self.head_x(x)
        y = self.head_y(y)

        # Outer product of token embeddings -> association logits.
        A = torch.einsum("bnd,bmd->bnm", x, y)
        return A

    def normalize_output(
        self,
        A: torch.FloatTensor,
        timepoints: torch.LongTensor,
        coords: torch.FloatTensor,
    ) -> torch.FloatTensor:
        """Apply (parental) softmax, or elementwise sigmoid.

        Args:
            A: Tensor of shape B, N, N
            timepoints: Tensor of shape B, N
            coords: Tensor of shape B, N, (time + n_spatial)
        """
        assert A.ndim == 3
        assert timepoints.ndim == 2
        assert coords.ndim == 3
        assert coords.shape[2] == 1 + self.config["coord_dim"]

        # Spatially distant pairs are invalid associations.
        dist = torch.cdist(coords[:, :, 1:], coords[:, :, 1:])
        invalid = dist > self.config["spatial_pos_cutoff"]

        if self.config["causal_norm"] == "none":
            # Elementwise sigmoid; spatially distant entries are set to zero.
            A = torch.sigmoid(A)
            A[invalid] = 0
        else:
            return torch.stack([
                blockwise_causal_norm(
                    _A, _t, mode=self.config["causal_norm"], mask_invalid=_m
                )
                for _A, _t, _m in zip(A, timepoints, invalid)
            ])
        return A

    def save(self, folder):
        """Write config.yaml and model.pt into `folder` (created if missing)."""
        folder = Path(folder)
        folder.mkdir(parents=True, exist_ok=True)
        yaml.safe_dump(self.config, open(folder / "config.yaml", "w"))
        torch.save(self.state_dict(), folder / "model.pt")

    @staticmethod
    def _check_args_against_config(config: dict, args) -> None:
        """Raise ValueError if any CLI arg contradicts the stored config.

        BUG FIX: the error accumulator was previously re-created inside the
        per-key loop, so only the mismatch of the key currently inspected could
        ever be reported; now all mismatches are collected before raising.
        """
        if not args:
            return
        args = vars(args)
        errors = [
            f"Loaded model config {k}={config[k]}, but current argument"
            f" {k}={args[k]}."
            for k in config
            if k in args and config[k] != args[k]
        ]
        if errors:
            raise ValueError("\n".join(errors))

    @classmethod
    def from_folder(
        cls, folder, map_location=None, args=None, checkpoint_path: str = "model.pt"
    ):
        """Reconstruct a model from config.yaml + a checkpoint in `folder`.

        Raises:
            ValueError: if `args` contradicts the stored config.
        """
        folder = Path(folder)
        config = yaml.load(open(folder / "config.yaml"), Loader=yaml.FullLoader)
        cls._check_args_against_config(config, args)

        model = cls(**config)

        fpath = folder / checkpoint_path
        logger.info(f"Loading model state from {fpath}")

        state = torch.load(fpath, map_location=map_location, weights_only=True)
        # If state is a lightning checkpoint, extract and de-prefix the state_dict.
        if "state_dict" in state:
            state = state["state_dict"]
            state = OrderedDict(
                (k[6:], v) for k, v in state.items() if k.startswith("model.")
            )
        model.load_state_dict(state)
        return model

    @classmethod
    def from_cfg(cls, cfg_path, args=None):
        """Instantiate an (untrained) model from a YAML config file."""
        cfg_path = Path(cfg_path)
        config = yaml.load(open(cfg_path), Loader=yaml.FullLoader)
        cls._check_args_against_config(config, args)
        return cls(**config)
logger = logging.getLogger(__name__)


class Trackastra:
    """A transformer-based tracking model for time-lapse data.

    Trackastra links segmented objects across time frames by predicting
    associations with a transformer model trained on diverse time-lapse videos.

    The model takes as input:
    - A sequence of images of shape (T,(Z),Y,X)
    - Corresponding instance segmentation masks of shape (T,(Z),Y,X)

    It supports multiple tracking modes:
    - greedy_nodiv: Fast greedy linking without division
    - greedy: Fast greedy linking with division
    - ilp: Integer Linear Programming based linking (more accurate but slower)

    Examples:
        >>> from trackastra.data import example_data_bacteria
        >>> imgs, masks = example_data_bacteria()
        >>> model = Trackastra.from_pretrained("general_2d", device="cuda")
        >>> track_graph = model.track(imgs, masks, mode="greedy")
    """

    def __init__(
        self,
        transformer: TrackingTransformer,
        train_args: dict,
        device: Literal["cuda", "mps", "cpu", "automatic", None] = None,
    ):
        """Initialize Trackastra model.

        Args:
            transformer: The underlying transformer model.
            train_args: Training configuration arguments.
            device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
        """
        self.device = self._resolve_device(device)
        logger.info(f"Using device {self.device}")
        self.transformer = transformer.to(self.device)
        self.train_args = train_args

    @staticmethod
    def _resolve_device(device) -> str:
        """Resolve the requested device to an available backend (fallback: cpu)."""

        def _mps_flag():
            return os.getenv("PYTORCH_ENABLE_MPS_FALLBACK")

        def _mps_enabled():
            # MPS is opt-in: the env flag must be set and not "0".
            flag = _mps_flag()
            return torch.backends.mps.is_available() and flag is not None and flag != "0"

        if device == "cuda":
            if torch.cuda.is_available():
                return "cuda"
            logger.info("Cuda not available, falling back to cpu.")
            return "cpu"
        if device == "mps":
            if _mps_enabled():
                return "mps"
            logger.info("Mps not available, falling back to cpu.")
            return "cpu"
        if device == "cpu":
            return "cpu"
        if device == "automatic" or device is None:
            if torch.cuda.is_available():
                return "cuda"
            # NOTE(review): as in the original, an *empty* flag string counts
            # as set-but-falsy here and disables mps.
            return "mps" if _mps_enabled() and _mps_flag() else "cpu"
        raise ValueError(f"Device {device} not recognized.")

    @classmethod
    def from_folder(cls, dir: Union[Path, str], device: Optional[str] = None):
        """Load a Trackastra model from a local folder.

        Args:
            dir: Path to model folder containing the model weights and a
                train_config.yaml with training arguments.
            device: Device to run model on.

        Returns:
            Trackastra model instance.
        """
        # BUG FIX: `dir` may be a plain str, in which case `dir / "..."` below
        # failed; convert once up front.
        dir = Path(dir).expanduser()
        # Always load to cpu first.
        transformer = TrackingTransformer.from_folder(dir, map_location="cpu")
        train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader)
        return cls(transformer=transformer, train_args=train_args, device=device)

    @classmethod
    def from_pretrained(
        cls, name: str, device: Optional[str] = None, download_dir: Optional[Path] = None
    ):
        """Load a pretrained Trackastra model.

        Available pretrained models are described in detail in pretrained.json.

        Args:
            name: Name of pretrained model (e.g. "general_2d").
            device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None).
            download_dir: Directory to download model to (defaults to ~/.cache/trackastra).

        Returns:
            Trackastra model instance.
        """
        # Download zip from github to location/name, then unzip.
        folder = download_pretrained(name, download_dir)
        return cls.from_folder(folder, device=device)

    def _predict(
        self,
        imgs: Union[np.ndarray, da.Array],
        masks: Union[np.ndarray, da.Array],
        edge_threshold: float = 0.05,
        n_workers: int = 0,
        normalize_imgs: bool = True,
        progbar_class=tqdm,
    ):
        """Extract features, build sliding windows and predict association weights."""
        logger.info("Predicting weights for candidate graph")
        if normalize_imgs:
            # Normalize lazily for dask arrays, eagerly otherwise.
            if isinstance(imgs, da.Array):
                imgs = imgs.map_blocks(normalize)
            else:
                imgs = normalize(imgs)

        self.transformer.eval()

        features = get_features(
            detections=masks,
            imgs=imgs,
            ndim=self.transformer.config["coord_dim"],
            n_workers=n_workers,
            progbar_class=progbar_class,
        )
        logger.info("Building windows")
        windows = build_windows(
            features,
            window_size=self.transformer.config["window"],
            progbar_class=progbar_class,
        )

        logger.info("Predicting windows")
        predictions = predict_windows(
            windows=windows,
            features=features,
            model=self.transformer,
            edge_threshold=edge_threshold,
            spatial_dim=masks.ndim - 1,
            progbar_class=progbar_class,
        )
        return predictions

    def _track_from_predictions(
        self,
        predictions,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        use_distance: bool = False,
        max_distance: int = 256,
        max_neighbors: int = 10,
        delta_t: int = 1,
        **kwargs,
    ):
        """Build the candidate graph and run the requested linking algorithm."""
        logger.info("Running greedy tracker")
        candidate_graph = build_graph(
            nodes=predictions["nodes"],
            weights=predictions["weights"],
            use_distance=use_distance,
            max_distance=max_distance,
            max_neighbors=max_neighbors,
            delta_t=delta_t,
        )
        if mode == "greedy":
            return track_greedy(candidate_graph)
        elif mode == "greedy_nodiv":
            return track_greedy(candidate_graph, allow_divisions=False)
        elif mode == "ilp":
            from trackastra.tracking.ilp import track_ilp

            return track_ilp(candidate_graph, ilp_config="gt", **kwargs)
        else:
            raise ValueError(f"Tracking mode {mode} does not exist.")

    def track(
        self,
        imgs: Union[np.ndarray, da.Array],
        masks: Union[np.ndarray, da.Array],
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        normalize_imgs: bool = True,
        progbar_class=tqdm,
        n_workers: int = 0,
        **kwargs,
    ) -> TrackGraph:
        """Track objects across time frames.

        Args:
            imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array).
            masks: Instance segmentation masks of shape (T,(Z),Y,X).
            mode: Tracking mode ("greedy_nodiv", "greedy" or "ilp").
            normalize_imgs: Whether to normalize the images.
            progbar_class: Progress bar class to use.
            n_workers: Number of worker processes for feature extraction.
            **kwargs: Additional arguments passed to the tracking algorithm.

        Returns:
            TrackGraph containing the tracking results.
        """
        if not imgs.shape == masks.shape:
            raise RuntimeError(
                f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
            )

        if not imgs.ndim == self.transformer.config["coord_dim"] + 1:
            raise RuntimeError(
                f"images should be a sequence of {self.transformer.config['coord_dim']}D images"
            )

        predictions = self._predict(
            imgs,
            masks,
            normalize_imgs=normalize_imgs,
            progbar_class=progbar_class,
            n_workers=n_workers,
        )
        return self._track_from_predictions(predictions, mode=mode, **kwargs)

    def track_from_disk(
        self,
        imgs_path: Path,
        masks_path: Path,
        mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy",
        normalize_imgs: bool = True,
        **kwargs,
    ) -> Tuple[TrackGraph, np.ndarray]:
        """Track objects directly from image and mask files on disk.

        Supports both single tiff files (time series of shape T,(C),(Z),Y,X)
        and directories of numbered tiff files.

        Args:
            imgs_path: Path to input images (directory or single tiff file).
            masks_path: Path to mask files (directory or single tiff file).
            mode: Tracking mode ("greedy_nodiv", "greedy" or "ilp").
            normalize_imgs: Whether to normalize the images.
            **kwargs: Additional arguments passed to the tracking algorithm.

        Returns:
            Tuple of (TrackGraph, tracked masks).
        """
        if not imgs_path.exists():
            raise FileNotFoundError(f"{imgs_path=} does not exist.")
        if not masks_path.exists():
            raise FileNotFoundError(f"{masks_path=} does not exist.")

        if imgs_path.is_dir():
            imgs = load_tiff_timeseries(imgs_path)
        else:
            imgs = tifffile.imread(imgs_path)

        if masks_path.is_dir():
            masks = load_tiff_timeseries(masks_path)
        else:
            masks = tifffile.imread(masks_path)

        if len(imgs) != len(masks):
            raise RuntimeError(
                f"#imgs and #masks do not match. Found {len(imgs)} images,"
                f" {len(masks)} masks."
            )

        if imgs.ndim - 1 == masks.ndim:
            # imgs carries an extra (channel) axis at position 1.
            # BUG FIX: was `imgs[1] == 1` (compared pixel data to 1).
            if imgs.shape[1] == 1:
                logger.info(
                    "Found a channel dimension with a single channel. Removing dim."
                )
                # BUG FIX: was squeezing `masks`, which lacks the channel axis.
                imgs = np.squeeze(imgs, 1)
            else:
                raise RuntimeError(
                    "Trackastra currently only supports single channel images."
                )

        if imgs.shape != masks.shape:
            raise RuntimeError(
                f"Img shape {imgs.shape} and mask shape {masks.shape} do not match."
            )

        return self.track(
            imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs
        ), masks
Found {len(imgs)} images," + f" {len(masks)} masks." + ) + + if imgs.ndim - 1 == masks.ndim: + if imgs[1] == 1: + logger.info( + "Found a channel dimension with a single channel. Removing dim." + ) + masks = np.squeeze(masks, 1) + else: + raise RuntimeError( + "Trackastra currently only supports single channel images." + ) + + if imgs.shape != masks.shape: + raise RuntimeError( + f"Img shape {imgs.shape} and mask shape {masks.shape} do not match." + ) + + return self.track( + imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs + ), masks diff --git a/models/tra_post_model/trackastra/model/model_parts.py b/models/tra_post_model/trackastra/model/model_parts.py new file mode 100644 index 0000000000000000000000000000000000000000..63e27a0a0566002155eeec9d1566b565be5b77ab --- /dev/null +++ b/models/tra_post_model/trackastra/model/model_parts.py @@ -0,0 +1,287 @@ +"""Transformer class.""" + +import logging +import math +from typing import Literal + +import torch +import torch.nn.functional as F +from torch import nn + +from .rope import RotaryPositionalEncoding + +from typing import Tuple + +logger = logging.getLogger(__name__) + + +def _pos_embed_fourier1d_init( + cutoff: float = 256, n: int = 32, cutoff_start: float = 1 +): + return ( + torch.exp(torch.linspace(-math.log(cutoff_start), -math.log(cutoff), n)) + .unsqueeze(0) + .unsqueeze(0) + ) + + +class FeedForward(nn.Module): + def __init__(self, d_model, expand: float = 2, bias: bool = True): + super().__init__() + self.fc1 = nn.Linear(d_model, int(d_model * expand)) + self.fc2 = nn.Linear(int(d_model * expand), d_model, bias=bias) + self.act = nn.GELU() + + def forward(self, x): + return self.fc2(self.act(self.fc1(x))) + + +class PositionalEncoding(nn.Module): + def __init__( + self, + cutoffs: Tuple[float] = (256,), + n_pos: Tuple[int] = (32,), + cutoffs_start=None, + ): + """Positional encoding with given cutoff and number of frequencies for each dimension. 
+ number of dimension is inferred from the length of cutoffs and n_pos. + """ + super().__init__() + if cutoffs_start is None: + cutoffs_start = (1,) * len(cutoffs) + + assert len(cutoffs) == len(n_pos) + self.freqs = nn.ParameterList([ + nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2)) + for cutoff, n, cutoff_start in zip(cutoffs, n_pos, cutoffs_start) + ]) + + def forward(self, coords: torch.Tensor): + _B, _N, D = coords.shape + assert D == len(self.freqs) + embed = torch.cat( + tuple( + torch.cat( + ( + torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq), + torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq), + ), + axis=-1, + ) + / math.sqrt(len(freq)) + for x, freq in zip(coords.moveaxis(-1, 0), self.freqs) + ), + axis=-1, + ) + + return embed + + +class NoPositionalEncoding(nn.Module): + def __init__(self, d): + """One learnable input token that ignores positional information.""" + super().__init__() + self.d = d + # self.token = nn.Parameter(torch.randn(d)) + + def forward(self, coords: torch.Tensor): + B, N, _ = coords.shape + return ( + # torch.ones((B, N, self.d), device=coords.device) * 0.1 + # torch.randn((1, 1, self.d), device=coords.device).expand(B, N, -1) * 0.01 + torch.randn((B, N, self.d), device=coords.device) * 0.01 + + torch.randn((1, 1, self.d), device=coords.device).expand(B, N, -1) * 0.1 + ) + # return self.token.view(1, 1, -1).expand(B, N, -1) + + +def _bin_init_exp(cutoff: float, n: int): + return torch.exp(torch.linspace(0, math.log(cutoff + 1), n)) + + +def _bin_init_linear(cutoff: float, n: int): + return torch.linspace(-cutoff, cutoff, n) + + +class RelativePositionalBias(nn.Module): + def __init__( + self, + n_head: int, + cutoff_spatial: float, + cutoff_temporal: float, + n_spatial: int = 32, + n_temporal: int = 16, + ): + """Learnt relative positional bias to add to self-attention matrix. + + Spatial bins are exponentially spaced, temporal bins are linearly spaced. + + Args: + n_head (int): Number of pos bias heads. 
Equal to number of attention heads + cutoff_spatial (float): Maximum distance in space. + cutoff_temporal (float): Maxium distance in time. Equal to window size of transformer. + n_spatial (int, optional): Number of spatial bins. + n_temporal (int, optional): Number of temporal bins in each direction. Should be equal to window size. Total = 2 * n_temporal + 1. Defaults to 16. + """ + super().__init__() + self._spatial_bins = _bin_init_exp(cutoff_spatial, n_spatial) + self._temporal_bins = _bin_init_linear(cutoff_temporal, 2 * n_temporal + 1) + self.register_buffer("spatial_bins", self._spatial_bins) + self.register_buffer("temporal_bins", self._temporal_bins) + self.n_spatial = n_spatial + self.n_head = n_head + self.bias = nn.Parameter( + -0.5 + torch.rand((2 * n_temporal + 1) * n_spatial, n_head) + ) + + def forward(self, coords: torch.Tensor): + _B, _N, _D = coords.shape + t = coords[..., 0] + yx = coords[..., 1:] + temporal_dist = t.unsqueeze(-1) - t.unsqueeze(-2) + spatial_dist = torch.cdist(yx, yx) + + spatial_idx = torch.bucketize(spatial_dist, self.spatial_bins) + torch.clamp_(spatial_idx, max=len(self.spatial_bins) - 1) + temporal_idx = torch.bucketize(temporal_dist, self.temporal_bins) + torch.clamp_(temporal_idx, max=len(self.temporal_bins) - 1) + + # do some index gymnastics such that backward is not super slow + # https://discuss.pytorch.org/t/how-to-select-multiple-indexes-over-multiple-dimensions-at-the-same-time/98532/2 + idx = spatial_idx.flatten() + temporal_idx.flatten() * self.n_spatial + bias = self.bias.index_select(0, idx).view((*spatial_idx.shape, self.n_head)) + # -> B, nH, N, N + bias = bias.transpose(-1, 1) + return bias + + +class RelativePositionalAttention(nn.Module): + def __init__( + self, + coord_dim: int, + embed_dim: int, + n_head: int, + cutoff_spatial: float = 256, + cutoff_temporal: float = 16, + n_spatial: int = 32, + n_temporal: int = 16, + dropout: float = 0.0, + mode: Literal["bias", "rope", "none"] = "bias", + 
attn_dist_mode: str = "v0", + ): + super().__init__() + + if not embed_dim % (2 * n_head) == 0: + raise ValueError( + f"embed_dim {embed_dim} must be divisible by 2 times n_head {2 * n_head}" + ) + + # qkv projection + self.q_pro = nn.Linear(embed_dim, embed_dim, bias=True) + self.k_pro = nn.Linear(embed_dim, embed_dim, bias=True) + self.v_pro = nn.Linear(embed_dim, embed_dim, bias=True) + + # output projection + self.proj = nn.Linear(embed_dim, embed_dim) + # regularization + self.dropout = dropout + self.n_head = n_head + self.embed_dim = embed_dim + self.cutoff_spatial = cutoff_spatial + self.attn_dist_mode = attn_dist_mode + + if mode == "bias" or mode is True: + self.pos_bias = RelativePositionalBias( + n_head=n_head, + cutoff_spatial=cutoff_spatial, + cutoff_temporal=cutoff_temporal, + n_spatial=n_spatial, + n_temporal=n_temporal, + ) + elif mode == "rope": + # each part needs to be divisible by 2 + n_split = 2 * (embed_dim // (2 * (coord_dim + 1) * n_head)) + + self.rot_pos_enc = RotaryPositionalEncoding( + cutoffs=((cutoff_temporal,) + (cutoff_spatial,) * coord_dim), + n_pos=(embed_dim // n_head - coord_dim * n_split,) + + (n_split,) * coord_dim, + ) + elif mode == "none": + pass + elif mode is None or mode is False: + logger.warning( + "attn_positional_bias is not set (None or False), no positional bias." 
+ ) + pass + else: + raise ValueError(f"Unknown mode {mode}") + + self._mode = mode + + def forward( + self, + query: torch.Tensor, + key: torch.Tensor, + value: torch.Tensor, + coords: torch.Tensor, + padding_mask: torch.Tensor = None, + ): + B, N, D = query.size() + q = self.q_pro(query) # (B, N, D) + k = self.k_pro(key) # (B, N, D) + v = self.v_pro(value) # (B, N, D) + # (B, nh, N, hs) + k = k.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) + q = q.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) + v = v.view(B, N, self.n_head, D // self.n_head).transpose(1, 2) + + attn_mask = torch.zeros( + (B, self.n_head, N, N), device=query.device, dtype=q.dtype + ) + + # add negative value but not too large to keep mixed precision loss from becoming nan + attn_ignore_val = -1e3 + + # spatial cutoff + yx = coords[..., 1:] + spatial_dist = torch.cdist(yx, yx) + spatial_mask = (spatial_dist > self.cutoff_spatial).unsqueeze(1) + attn_mask.masked_fill_(spatial_mask, attn_ignore_val) + + # dont add positional bias to self-attention if coords is None + if coords is not None: + if self._mode == "bias": + attn_mask = attn_mask + self.pos_bias(coords) + elif self._mode == "rope": + q, k = self.rot_pos_enc(q, k, coords) + else: + pass + + if self.attn_dist_mode == "v0": + dist = torch.cdist(coords, coords, p=2) + attn_mask += torch.exp(-0.1 * dist.unsqueeze(1)) + elif self.attn_dist_mode == "v1": + attn_mask += torch.exp( + -5 * spatial_dist.unsqueeze(1) / self.cutoff_spatial + ) + else: + raise ValueError(f"Unknown attn_dist_mode {self.attn_dist_mode}") + + # if given key_padding_mask = (B,N) then ignore those tokens (e.g. 
padding tokens) + if padding_mask is not None: + ignore_mask = torch.logical_or( + padding_mask.unsqueeze(1), padding_mask.unsqueeze(2) + ).unsqueeze(1) + attn_mask.masked_fill_(ignore_mask, attn_ignore_val) + + # self.attn_mask = attn_mask.clone() + + y = F.scaled_dot_product_attention( + q, k, v, attn_mask=attn_mask, dropout_p=self.dropout if self.training else 0 + ) + + y = y.transpose(1, 2).contiguous().view(B, N, D) + # output projection + y = self.proj(y) + return y diff --git a/models/tra_post_model/trackastra/model/model_sd.py b/models/tra_post_model/trackastra/model/model_sd.py new file mode 100644 index 0000000000000000000000000000000000000000..5765e8d5dea25a0031332d0e7023091a3e08776a --- /dev/null +++ b/models/tra_post_model/trackastra/model/model_sd.py @@ -0,0 +1,338 @@ +import logging +import os +from pathlib import Path +from typing import Literal, Union, Optional, Tuple + +import dask.array as da +import numpy as np +import tifffile +import torch +import yaml +from tqdm import tqdm + +from ..data import build_windows, get_features, load_tiff_timeseries +from ..tracking import TrackGraph, build_graph, track_greedy +from ..utils import normalize +from .model import TrackingTransformer +from .predict import predict_windows +from .pretrained import download_pretrained + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + + +class Trackastra: + """A transformer-based tracking model for time-lapse data. + + Trackastra links segmented objects across time frames by predicting + associations with a transformer model trained on diverse time-lapse videos. 
+ + The model takes as input: + - A sequence of images of shape (T,(Z),Y,X) + - Corresponding instance segmentation masks of shape (T,(Z),Y,X) + + It supports multiple tracking modes: + - greedy_nodiv: Fast greedy linking without division + - greedy: Fast greedy linking with division + - ilp: Integer Linear Programming based linking (more accurate but slower) + + Examples: + >>> # Load example data + >>> from trackastra.data import example_data_bacteria + >>> imgs, masks = example_data_bacteria() + >>> + >>> # Load pretrained model and track + >>> model = Trackastra.from_pretrained("general_2d", device="cuda") + >>> track_graph = model.track(imgs, masks, mode="greedy") + """ + + def __init__( + self, + transformer: TrackingTransformer, + train_args: dict, + device: Literal["cuda", "mps", "cpu", "automatic", None] = None, + ): + """Initialize Trackastra model. + + Args: + transformer: The underlying transformer model. + train_args: Training configuration arguments. + device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None). 
+ """ + if device == "cuda": + if torch.cuda.is_available(): + self.device = "cuda" + else: + logger.info("Cuda not available, falling back to cpu.") + self.device = "cpu" + elif device == "mps": + if ( + torch.backends.mps.is_available() + and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None + and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0" + ): + self.device = "mps" + else: + logger.info("Mps not available, falling back to cpu.") + self.device = "cpu" + elif device == "cpu": + self.device = "cpu" + elif device == "automatic" or device is None: + should_use_mps = ( + torch.backends.mps.is_available() + and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") is not None + and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") != "0" + ) + self.device = ( + "cuda" + if torch.cuda.is_available() + else ( + "mps" + if should_use_mps and os.getenv("PYTORCH_ENABLE_MPS_FALLBACK") + else "cpu" + ) + ) + else: + raise ValueError(f"Device {device} not recognized.") + + logger.info(f"Using device {self.device}") + + self.transformer = transformer.to(self.device) + self.train_args = train_args + + @classmethod + def from_folder(cls, dir: Union[Path, str], device: Optional[str] = None): + """Load a Trackastra model from a local folder. + + Args: + dir: Path to model folder containing: + - model weights + - train_config.yaml with training arguments + device: Device to run model on. + + Returns: + Trackastra model instance. + """ + # Always load to cpu first + transformer = TrackingTransformer.from_folder( + Path(dir).expanduser(), map_location="cpu" + ) + train_args = yaml.load(open(dir / "train_config.yaml"), Loader=yaml.FullLoader) + return cls(transformer=transformer, train_args=train_args, device=device) + + @classmethod + def from_pretrained( + cls, name: str, device: Optional[str] = None, download_dir: Optional[Path] = None + ): + """Load a pretrained Trackastra model. + + Available pretrained models are described in detail in pretrained.json. 
+ + Args: + name: Name of pretrained model (e.g. "general_2d"). + device: Device to run model on ("cuda", "mps", "cpu", "automatic" or None). + download_dir: Directory to download model to (defaults to ~/.cache/trackastra). + + Returns: + Trackastra model instance. + """ + folder = download_pretrained(name, download_dir) + # download zip from github to location/name, then unzip + return cls.from_folder(folder, device=device) + + def _predict( + self, + imgs: Union[np.ndarray, da.Array], + masks: Union[np.ndarray, da.Array], + edge_threshold: float = 0.05, + n_workers: int = 0, + normalize_imgs: bool = True, + progbar_class=tqdm, + ): + logger.info("Predicting weights for candidate graph") + if normalize_imgs: + if isinstance(imgs, da.Array): + imgs = imgs.map_blocks(normalize) + else: + imgs = normalize(imgs) + + self.transformer.eval() + + features = get_features( + detections=masks, + imgs=imgs, + ndim=self.transformer.config["coord_dim"], + n_workers=n_workers, + progbar_class=progbar_class, + ) + logger.info("Building windows") + windows = build_windows( + features, + window_size=self.transformer.config["window"], + progbar_class=progbar_class, + ) + + logger.info("Predicting windows") + predictions = predict_windows( + windows=windows, + features=features, + model=self.transformer, + edge_threshold=edge_threshold, + spatial_dim=masks.ndim - 1, + progbar_class=progbar_class, + ) + + return predictions + + def _track_from_predictions( + self, + predictions, + mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy", + use_distance: bool = False, + max_distance: int = 256, + max_neighbors: int = 10, + delta_t: int = 1, + **kwargs, + ): + logger.info("Running greedy tracker") + nodes = predictions["nodes"] + weights = predictions["weights"] + + candidate_graph = build_graph( + nodes=nodes, + weights=weights, + use_distance=use_distance, + max_distance=max_distance, + max_neighbors=max_neighbors, + delta_t=delta_t, + ) + if mode == "greedy": + return 
track_greedy(candidate_graph) + elif mode == "greedy_nodiv": + return track_greedy(candidate_graph, allow_divisions=False) + elif mode == "ilp": + from trackastra.tracking.ilp import track_ilp + + return track_ilp(candidate_graph, ilp_config="gt", **kwargs) + else: + raise ValueError(f"Tracking mode {mode} does not exist.") + + def track( + self, + imgs: Union[np.ndarray, da.Array], + masks: Union[np.ndarray, da.Array], + mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy", + normalize_imgs: bool = True, + progbar_class=tqdm, + n_workers: int = 0, + **kwargs, + ) -> TrackGraph: + """Track objects across time frames. + + This method links segmented objects across time frames using the specified + tracking mode. No hyperparameters need to be chosen beyond the tracking mode. + + Args: + imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array) + masks: Instance segmentation masks of shape (T,(Z),Y,X). + mode: Tracking mode: + - "greedy_nodiv": Fast greedy linking without division + - "greedy": Fast greedy linking with division + - "ilp": Integer Linear Programming based linking (more accurate but slower) + progbar_class: Progress bar class to use. + n_workers: Number of worker processes for feature extraction. + normalize_imgs: Whether to normalize the images. + **kwargs: Additional arguments passed to tracking algorithm. + + Returns: + TrackGraph containing the tracking results. + """ + if not imgs.shape == masks.shape: + raise RuntimeError( + f"Img shape {imgs.shape} and mask shape {masks.shape} do not match." 
+ ) + + if not imgs.ndim == self.transformer.config["coord_dim"] + 1: + raise RuntimeError( + f"images should be a sequence of {self.transformer.config['coord_dim']}D images" + ) + + predictions = self._predict( + imgs, + masks, + normalize_imgs=normalize_imgs, + progbar_class=progbar_class, + n_workers=n_workers, + ) + track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs) + return track_graph + + def track_from_disk( + self, + imgs_path: Path, + masks_path: Path, + mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy", + normalize_imgs: bool = True, + **kwargs, + ) -> Tuple[TrackGraph, np.ndarray]: + """Track objects directly from image and mask files on disk. + + This method supports both single tiff files and directories + + Args: + imgs_path: Path to input images. Can be: + - Directory containing numbered tiff files of shape (C),(Z),Y,X + - Single tiff file with time series of shape T,(C),(Z),Y,X + masks_path: Path to mask files. Can be: + - Directory containing numbered tiff files of shape (Z),Y,X + - Single tiff file with time series of shape T,(Z),Y,X + mode: Tracking mode: + - "greedy_nodiv": Fast greedy linking without division + - "greedy": Fast greedy linking with division + - "ilp": Integer Linear Programming based linking (more accurate but slower) + normalize_imgs: Whether to normalize the images. + **kwargs: Additional arguments passed to tracking algorithm. + + Returns: + Tuple of (TrackGraph, tracked masks). + """ + if not imgs_path.exists(): + raise FileNotFoundError(f"{imgs_path=} does not exist.") + if not masks_path.exists(): + raise FileNotFoundError(f"{masks_path=} does not exist.") + + if imgs_path.is_dir(): + imgs = load_tiff_timeseries(imgs_path) + else: + imgs = tifffile.imread(imgs_path) + + if masks_path.is_dir(): + masks = load_tiff_timeseries(masks_path) + else: + masks = tifffile.imread(masks_path) + + if len(imgs) != len(masks): + raise RuntimeError( + f"#imgs and #masks do not match. 
Found {len(imgs)} images," + f" {len(masks)} masks." + ) + + if imgs.ndim - 1 == masks.ndim: + if imgs[1] == 1: + logger.info( + "Found a channel dimension with a single channel. Removing dim." + ) + masks = np.squeeze(masks, 1) + else: + raise RuntimeError( + "Trackastra currently only supports single channel images." + ) + + if imgs.shape != masks.shape: + raise RuntimeError( + f"Img shape {imgs.shape} and mask shape {masks.shape} do not match." + ) + + return self.track( + imgs, masks, mode, normalize_imgs=normalize_imgs, **kwargs + ), masks diff --git a/models/tra_post_model/trackastra/model/predict.py b/models/tra_post_model/trackastra/model/predict.py new file mode 100644 index 0000000000000000000000000000000000000000..871ec14ec803d027b59c5542ea317ffcd6520312 --- /dev/null +++ b/models/tra_post_model/trackastra/model/predict.py @@ -0,0 +1,188 @@ +import logging +import warnings + +import numpy as np +import torch +from scipy.sparse import SparseEfficiencyWarning, csr_array +from tqdm import tqdm +from typing import List + +# TODO fix circular import +# from .model import TrackingTransformer +# from trackastra.data import WRFeatures + +warnings.simplefilter("ignore", SparseEfficiencyWarning) + +logger = logging.getLogger(__name__) +logger.setLevel(logging.INFO) + + +def predict(batch, model): + """Predict association scores between objects in a batch. + + Args: + batch: Dictionary containing: + - features: Object features array + - coords: Object coordinates array + - timepoints: Time points array + model: TrackingTransformer model to use for prediction. + + Returns: + Array of association scores between objects. 
+ """ + feats = torch.from_numpy(batch["features"]) + coords = torch.from_numpy(batch["coords"]) + timepoints = torch.from_numpy(batch["timepoints"]).long() + # Hack that assumes that all parameters of a model are on the same device + device = next(model.parameters()).device + feats = feats.unsqueeze(0).to(device) + timepoints = timepoints.unsqueeze(0).to(device) + coords = coords.unsqueeze(0).to(device) + + # Concat timepoints to coordinates + coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2) + with torch.no_grad(): + A = model(coords, features=feats) + A = model.normalize_output(A, timepoints, coords) + + # # Spatially far entries should not influence the causal normalization + # dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:]) + # invalid = dist > model.config["spatial_pos_cutoff"] + # A[invalid] = -torch.inf + + A = A.squeeze(0).detach().cpu().numpy() + + return A + + +def predict_windows( + windows: List[dict], + # features: list[WRFeatures], + # model: TrackingTransformer, + features: list, + model, + intra_window_weight: float = 0, + delta_t: int = 1, + edge_threshold: float = 0.05, + spatial_dim: int = 3, + progbar_class=tqdm, +) -> dict: + """Predict associations between objects across sliding windows. + + This function processes a sequence of sliding windows to predict associations + between objects across time frames. It handles: + - Object tracking across time + - Weight normalization across windows + - Edge thresholding + - Time-based filtering + + Args: + windows: List of window dictionaries containing: + - timepoints: Array of time points + - labels: Array of object labels + - features: Object features + - coords: Object coordinates + features: List of feature objects containing: + - labels: Object labels + - timepoints: Time points + - coords: Object coordinates + model: TrackingTransformer model to use for prediction. + intra_window_weight: Weight factor for objects in middle of window. Defaults to 0. 
+ delta_t: Maximum time difference between objects to consider. Defaults to 1. + edge_threshold: Minimum association score to consider. Defaults to 0.05. + spatial_dim: Dimensionality of input masks. May be less than model.coord_dim. + progbar_class: Progress bar class to use. Defaults to tqdm. + + Returns: + Dictionary containing: + - nodes: List of node properties (id, coords, time, label) + - weights: Tuple of ((node_i, node_j), weight) pairs + """ + # first get all objects/coords + time_labels_to_id = dict() + node_properties = list() + max_id = np.sum([len(f.labels) for f in features]) + + all_timepoints = np.concatenate([f.timepoints for f in features]) + all_labels = np.concatenate([f.labels for f in features]) + all_coords = np.concatenate([f.coords for f in features]) + all_coords = all_coords[:, -spatial_dim:] + + for i, (t, la, c) in enumerate(zip(all_timepoints, all_labels, all_coords)): + time_labels_to_id[(t, la)] = i + node_properties.append( + dict( + id=i, + coords=tuple(c), + time=t, + # index=ix, + label=la, + ) + ) + + # create assoc matrix between ids + sp_weights, sp_accum = ( + csr_array((max_id, max_id), dtype=np.float32), + csr_array((max_id, max_id), dtype=np.float32), + ) + + for t in progbar_class( + range(len(windows)), + desc="Computing associations", + ): + # This assumes that the samples in the dataset are ordered by time and start at 0. 
+ batch = windows[t] + timepoints = batch["timepoints"] + labels = batch["labels"] + + A = predict(batch, model) + + dt = timepoints[None, :] - timepoints[:, None] + time_mask = np.logical_and(dt <= delta_t, dt > 0) + A[~time_mask] = 0 + ii, jj = np.where(A >= edge_threshold) + + if len(ii) == 0: + continue + + labels_ii = labels[ii] + labels_jj = labels[jj] + ts_ii = timepoints[ii] + ts_jj = timepoints[jj] + nodes_ii = np.array( + tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_ii, labels_ii)) + ) + nodes_jj = np.array( + tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_jj, labels_jj)) + ) + + # weight middle parts higher + t_middle = t + (model.config["window"] - 1) / 2 + ddt = timepoints[:, None] - t_middle * np.ones_like(dt) + window_weight = np.exp(-intra_window_weight * ddt**2) # default is 1 + # window_weight = np.exp(4*A) # smooth max + sp_weights[nodes_ii, nodes_jj] += window_weight[ii, jj] * A[ii, jj] + sp_accum[nodes_ii, nodes_jj] += window_weight[ii, jj] + + sp_weights_coo = sp_weights.tocoo() + sp_accum_coo = sp_accum.tocoo() + assert np.allclose(sp_weights_coo.col, sp_accum_coo.col) and np.allclose( + sp_weights_coo.row, sp_accum_coo.row + ) + + # Normalize weights by the number of times they were written from different sliding window positions + weights = tuple( + ((i, j), v / a) + for i, j, v, a in zip( + sp_weights_coo.row, + sp_weights_coo.col, + sp_weights_coo.data, + sp_accum_coo.data, + ) + ) + + results = dict() + results["nodes"] = node_properties + results["weights"] = weights + + return results diff --git a/models/tra_post_model/trackastra/model/pretrained.json b/models/tra_post_model/trackastra/model/pretrained.json new file mode 100644 index 0000000000000000000000000000000000000000..a00137d007bd107c41fa5183df0529755cbc53e9 --- /dev/null +++ b/models/tra_post_model/trackastra/model/pretrained.json @@ -0,0 +1,81 @@ +{ + "general_2d": { + "tags": ["cells, nuclei, bacteria, epithelial, yeast, particles"], + "dimensionality": [2], 
+ "description": "For tracking fluorescent nuclei, bacteria (PhC), whole cells (BF, PhC, DIC), epithelial cells with fluorescent membrane, budding yeast cells (PhC), fluorescent particles, .", + "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip", + "datasets": { + "Subset of Cell Tracking Challenge 2d datasets": { + "url": "https://celltrackingchallenge.net/2d-datasets/", + "reference": "Maška M, Ulman V, Delgado-Rodriguez P, Gómez-de-Mariscal E, Nečasová T, Guerrero Peña FA, Ren TI, Meyerowitz EM, Scherr T, Löffler K, Mikut R. The Cell Tracking Challenge: 10 years of objective benchmarking. Nature Methods. 2023 Jul;20(7):1010-20." + }, + "Bacteria van Vliet": { + "url": "https://zenodo.org/records/268921", + "reference": "van Vliet S, Winkler AR, Spriewald S, Stecher B, Ackermann M. Spatially correlated gene expression in bacterial groups: the role of lineage history, spatial gradients, and cell-cell interactions. Cell systems. 2018 Apr 25;6(4):496-507." + }, + "Bacteria ObiWan-Microbi": { + "url": "https://zenodo.org/records/7260137", + "reference": "Seiffarth J, Scherr T, Wollenhaupt B, Neumann O, Scharr H, Kohlheyer D, Mikut R, Nöh K. ObiWan-Microbi: OMERO-based integrated workflow for annotating microbes in the cloud. SoftwareX. 2024 May 1;26:101638." + }, + "Bacteria Persat": { + "url": "https://www.p-lab.science", + "reference": "Datasets kindly provided by Persat lab, EPFL." + }, + "DeepCell": { + "url": "https://datasets.deepcell.org/data", + "reference": "Schwartz, M, Moen E, Miller G, Dougherty T, Borba E, Ding R, Graf W, Pao E, Van Valen D. Caliban: Accurate cell tracking and lineage construction in live-cell imaging experiments with deep learning. Biorxiv. 2023 Sept 13:803205." + }, + "Ker phase contrast": { + "url": "https://osf.io/ysaq2/", + "reference": "Ker DF, Eom S, Sanami S, Bise R, Pascale C, Yin Z, Huh SI, Osuna-Highley E, Junkers SN, Helfrich CJ, Liang PY. 
Phase contrast time-lapse microscopy datasets with automated and manual cell tracking annotations. Scientific data. 2018 Nov 13;5(1):1-2." + }, + "Epithelia benchmark": { + "reference": "Funke J, Mais L, Champion A, Dye N, Kainmueller D. A benchmark for epithelial cell tracking. InProceedings of The European Conference on Computer Vision (ECCV) Workshops 2018 (pp. 0-0)." + }, + "T Cells": { + "url": "https://zenodo.org/records/5206119" + }, + "Neisseria meningitidis bacterial growth": { + "url": "https://zenodo.org/records/5419619" + }, + "Synthetic nuclei": { + "reference": "Weigert group live cell simulator." + }, + "Synthetic particles": { + "reference": "Weigert group particle simulator." + }, + "Particle Tracking Challenge": { + "url": "http://bioimageanalysis.org/track/#data", + "reference": "Chenouard, N., Smal, I., De Chaumont, F., Maška, M., Sbalzarini, I. F., Gong, Y., ... & Meijering, E. (2014). Objective comparison of particle tracking methods. Nature methods, 11(3), 281-289." + }, + "Yeast Cell-ACDC": { + "url": "https://zenodo.org/records/6795124", + "reference": "Padovani, F., Mairhörmann, B., Falter-Braun, P., Lengefeld, J., & Schmoller, K. M. (2022). Segmentation, tracking and cell cycle analysis of live-cell imaging data with Cell-ACDC. BMC biology, 20(1), 174." + }, + "DeepSea": { + "url": "https://deepseas.org/datasets/", + "reference": "Zargari, A., Lodewijk, G. A., Mashhadi, N., Cook, N., Neudorf, C. W., Araghbidikashani, K., ... & Shariati, S. A. (2023). DeepSea is an efficient deep-learning model for single-cell segmentation and tracking in time-lapse microscopy. Cell Reports Methods, 3(6)." + }, + "Btrack" : { + "url": "https://rdr.ucl.ac.uk/articles/dataset/Cell_tracking_reference_dataset/16595978", + "reference": "Ulicna, K., Vallardi, G., Charras, G., & Lowe, A. R. (2021). Automated deep lineage tree analysis using a Bayesian single cell tracking approach. Frontiers in Computer Science, 3, 734559." + }, + "E. 
coli in mother machine": { + "url": "https://zenodo.org/records/11237127", + "reference": "O’Connor, O. M., & Dunlop, M. J. (2024). Cell-TRACTR: A transformer-based model for end-to-end segmentation and tracking of cells. bioRxiv, 2024-07." + } + } + }, + "ctc": { + "tags": ["ctc", "Cell Tracking Challenge", "Cell Linking Benchmark"], + "dimensionality": [2, 3], + "description": "For tracking Cell Tracking Challenge datasets. This is the successor of the winning model of the ISBI 2024 CTC generalizable linking challenge.", + "url": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip", + "datasets": { + "All Cell Tracking Challenge 2d+3d datasets with available GT and ERR_SEG": { + "url": "https://celltrackingchallenge.net/3d-datasets/", + "reference": "Maška M, Ulman V, Delgado-Rodriguez P, Gómez-de-Mariscal E, Nečasová T, Guerrero Peña FA, Ren TI, Meyerowitz EM, Scherr T, Löffler K, Mikut R. The Cell Tracking Challenge: 10 years of objective benchmarking. Nature Methods. 2023 Jul;20(7):1010-20." 
+ } + } + } +} \ No newline at end of file diff --git a/models/tra_post_model/trackastra/model/pretrained.py b/models/tra_post_model/trackastra/model/pretrained.py new file mode 100644 index 0000000000000000000000000000000000000000..8bdcbf63375e23618917f19c7061feb32817a371 --- /dev/null +++ b/models/tra_post_model/trackastra/model/pretrained.py @@ -0,0 +1,90 @@ +import logging +import shutil +import tempfile +import zipfile +try: + from importlib.resources import files +except: + from importlib_resources import files +from pathlib import Path + +import requests +from tqdm import tqdm +from typing import Optional + +logger = logging.getLogger(__name__) + +_MODELS = { + "ctc": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/ctc.zip", + "general_2d": "https://github.com/weigertlab/trackastra-models/releases/download/v0.3.0/general_2d.zip", +} + + +def download_and_unzip(url: str, dst: Path): + # TODO make safe and use tempfile lib + if dst.exists(): + print(f"{dst} already downloaded, skipping.") + return + + # get the name of the zipfile + zip_base = Path(url.split("/")[-1]) + + with tempfile.TemporaryDirectory() as tmp: + tmp = Path(tmp) + zip_file = tmp / zip_base + # Download the zip file + download(url, zip_file) + + # Unzip the file + with zipfile.ZipFile(zip_file, "r") as zip_ref: + zip_ref.extractall(tmp) + + shutil.move(tmp / zip_base.stem, dst) + + +def download(url: str, fname: Path): + resp = requests.get(url, stream=True) + total = int(resp.headers.get("content-length", 0)) + # try: + # with (open(str(fname), "wb") as file, + # tqdm( + # desc=str(fname), + # total=total, + # unit="iB", + # unit_scale=True, + # unit_divisor=1024, + # ) as bar,): + # for data in resp.iter_content(chunk_size=1024): + # size = file.write(data) + # bar.update(size) + # except: + with open(str(fname), "wb") as file, tqdm( + desc=str(fname), + total=total, + unit="iB", + unit_scale=True, + unit_divisor=1024, + ) as bar: + for data in 
resp.iter_content(chunk_size=1024): + size = file.write(data) + bar.update(size) + + +def download_pretrained(name: str, download_dir: Optional[Path] = None): + # TODO make safe, introduce versioning + if download_dir is None: + download_dir = files("trackastra").joinpath(".models") + else: + download_dir = Path(download_dir) + + download_dir.mkdir(exist_ok=True, parents=True) + try: + url = _MODELS[name] + except KeyError: + raise ValueError( + "Pretrained model `name` is not available. Choose from" + f" {list(_MODELS.keys())}" + ) + folder = download_dir / name + download_and_unzip(url=url, dst=folder) + return folder diff --git a/models/tra_post_model/trackastra/model/rope.py b/models/tra_post_model/trackastra/model/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..9ffb57be77f6b797d70d316d83e06ed42ad38420 --- /dev/null +++ b/models/tra_post_model/trackastra/model/rope.py @@ -0,0 +1,94 @@ +"""Transformer class.""" + +# from torch_geometric.nn import GATv2Conv +import math + +import torch +from torch import nn +from typing import Tuple + + +def _pos_embed_fourier1d_init(cutoff: float = 128, n: int = 32): + # Maximum initial frequency is 1 + return torch.exp(torch.linspace(0, -math.log(cutoff), n)).unsqueeze(0).unsqueeze(0) + + +# https://github.com/cvg/LightGlue/blob/b1cd942fc4a3a824b6aedff059d84f5c31c297f6/lightglue/lightglue.py#L51 +def _rotate_half(x: torch.Tensor) -> torch.Tensor: + """Rotate pairs of scalars as 2d vectors by pi/2. + Refer to eq 34 in https://arxiv.org/pdf/2104.09864.pdf. + """ + x = x.unflatten(-1, (-1, 2)) + x1, x2 = x.unbind(dim=-1) + return torch.stack((-x2, x1), dim=-1).flatten(start_dim=-2) + + +class RotaryPositionalEncoding(nn.Module): + def __init__(self, cutoffs: Tuple[float] = (256,), n_pos: Tuple[int] = (32,)): + """Rotary positional encoding with given cutoff and number of frequencies for each dimension. + number of dimension is inferred from the length of cutoffs and n_pos. 
+ + see + https://arxiv.org/pdf/2104.09864.pdf + """ + super().__init__() + assert len(cutoffs) == len(n_pos) + if not all(n % 2 == 0 for n in n_pos): + raise ValueError("n_pos must be even") + + self._n_dim = len(cutoffs) + # theta in RoFormer https://arxiv.org/pdf/2104.09864.pdf + self.freqs = nn.ParameterList([ + nn.Parameter(_pos_embed_fourier1d_init(cutoff, n // 2)) + for cutoff, n in zip(cutoffs, n_pos) + ]) + + def get_co_si(self, coords: torch.Tensor): + _B, _N, D = coords.shape + assert D == len(self.freqs) + co = torch.cat( + tuple( + torch.cos(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq)) + for x, freq in zip(coords.moveaxis(-1, 0), self.freqs) + ), + axis=-1, + ) + si = torch.cat( + tuple( + torch.sin(0.5 * math.pi * x.unsqueeze(-1) * freq) / math.sqrt(len(freq)) + for x, freq in zip(coords.moveaxis(-1, 0), self.freqs) + ), + axis=-1, + ) + + return co, si + + def forward(self, q: torch.Tensor, k: torch.Tensor, coords: torch.Tensor): + _B, _N, D = coords.shape + _B, _H, _N, _C = q.shape + + if not D == self._n_dim: + raise ValueError(f"coords must have {self._n_dim} dimensions, got {D}") + + co, si = self.get_co_si(coords) + + co = co.unsqueeze(1).repeat_interleave(2, dim=-1) + si = si.unsqueeze(1).repeat_interleave(2, dim=-1) + q2 = q * co + _rotate_half(q) * si + k2 = k * co + _rotate_half(k) * si + + return q2, k2 + + +if __name__ == "__main__": + model = RotaryPositionalEncoding((256, 256), (32, 32)) + + x = 100 * torch.rand(1, 17, 2) + q = torch.rand(1, 4, 17, 64) + k = torch.rand(1, 4, 17, 64) + + q1, k1 = model(q, k, x) + A1 = q1[:, :, 0] @ k1[:, :, 0].transpose(-1, -2) + + q2, k2 = model(q, k, x + 10) + A2 = q2[:, :, 0] @ k2[:, :, 0].transpose(-1, -2) diff --git a/models/tra_post_model/trackastra/tracking/__init__.py b/models/tra_post_model/trackastra/tracking/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6759bd9ec535f870684bae891db815e83b9939b3 --- /dev/null +++ 
b/models/tra_post_model/trackastra/tracking/__init__.py @@ -0,0 +1,15 @@ +# ruff: noqa: F401 + +from .track_graph import TrackGraph +from .tracking import ( + build_graph, + track_greedy, +) +from .utils import ( + ctc_to_graph, + ctc_to_napari_tracks, + graph_to_ctc, + graph_to_edge_table, + graph_to_napari_tracks, + linear_chains, +) diff --git a/models/tra_post_model/trackastra/tracking/ilp.py b/models/tra_post_model/trackastra/tracking/ilp.py new file mode 100644 index 0000000000000000000000000000000000000000..ca8d186b266fdd665b7fb69430249989bf34d2fb --- /dev/null +++ b/models/tra_post_model/trackastra/tracking/ilp.py @@ -0,0 +1,182 @@ +import logging +import time +from types import SimpleNamespace + +import networkx as nx +import yaml + +try: + import motile +except ModuleNotFoundError: + raise ModuleNotFoundError( + "For tracking with an ILP, please conda install the optional `motile`" + " dependency following https://funkelab.github.io/motile/install.html." + ) + + +logger = logging.getLogger(__name__) + +ILP_CONFIGS = { + "gt": SimpleNamespace( + nodeW=0, + nodeC=-10, # take all nodes + edgeW=-1, + edgeC=0, + appearC=0.25, + disappearC=0.5, + splitC=0.25, + ), + "deepcell_gt": SimpleNamespace( + nodeW=0, + nodeC=-10, # take all nodes + edgeW=-1, + edgeC=0, + appearC=0.25, + disappearC=0.5, + splitC=1, + ), + "deepcell_gt_tuned": SimpleNamespace( + nodeW=0, + nodeC=-10, # take all nodes + edgeW=-1, + edgeC=0, + appearC=0.5, + disappearC=0.5, + splitC=1, + ), + "deepcell_res_tuned": SimpleNamespace( + nodeW=0, + nodeC=0.25, + edgeW=-1, + edgeC=-0.25, + appearC=0.25, + disappearC=0.25, + splitC=1.0, + ), +} + + +def track_ilp( + candidate_graph, + allow_divisions: bool = True, + ilp_config: str = "gt", + params_file: str | None = None, + **kwargs, +): + candidate_graph_motile = motile.TrackGraph(candidate_graph, frame_attribute="time") + + ilp, _used_costs = solve_full_ilp( + candidate_graph_motile, + allow_divisions=allow_divisions, + mode=ilp_config, + 
def solve_full_ilp(
    graph,
    allow_divisions: bool,
    mode: str,
    params_file: str | None,
):
    """Build and solve the tracking ILP on a motile track graph.

    Args:
        graph: `motile.TrackGraph` of candidate detections and links.
        allow_divisions: If True, a node may have up to two children.
        mode: Key into `ILP_CONFIGS` selecting preset cost parameters.
            Ignored when `params_file` is given.
        params_file: Optional YAML file with custom cost parameters.

    Returns:
        Tuple of the solved `motile.Solver` and a dict of the cost
        parameters that were actually applied.
    """
    solver = motile.Solver(graph)

    # Resolve cost parameters: an explicit file beats the preset config.
    if params_file:
        with open(params_file) as fh:
            p = yaml.safe_load(fh)
            # TODO more checks
            p = SimpleNamespace(**p)
            logger.info(f"Using ILP parameters {p}")
    else:
        try:
            p = ILP_CONFIGS[mode]
            logger.info(f"Using `{mode}` ILP config.")
        except KeyError:
            raise ValueError(
                f"Unknown ILP mode {mode}. Choose from {list(ILP_CONFIGS.keys())} or"
                " supply custom parameters via `params_file` argument."
            )

    # Record every cost parameter that gets applied.
    used = {}

    # NODES
    solver.add_cost(
        motile.costs.NodeSelection(weight=p.nodeW, constant=p.nodeC, attribute="weight")
    )
    used["nodeW"] = p.nodeW
    used["nodeC"] = p.nodeC

    # EDGES
    solver.add_cost(
        motile.costs.EdgeSelection(weight=p.edgeW, constant=p.edgeC, attribute="weight")
    )
    used["edgeW"] = p.edgeW
    used["edgeC"] = p.edgeC

    # APPEAR
    solver.add_cost(motile.costs.Appear(constant=p.appearC))
    used["appearC"] = p.appearC

    # DISAPPEAR
    solver.add_cost(motile.costs.Disappear(constant=p.disappearC))
    used["disappearC"] = p.disappearC

    # DIVISION
    if allow_divisions:
        solver.add_cost(motile.costs.Split(constant=p.splitC))
        used["splitC"] = p.splitC

    # Structural constraints: lineages form trees.
    solver.add_constraint(motile.constraints.MaxParents(1))
    solver.add_constraint(motile.constraints.MaxChildren(2 if allow_divisions else 1))

    solver.solve()

    return solver, used
def print_solution_stats(solver, graph, gt_graph=None):
    """Print node/edge counts for the candidate graph, the optional
    ground-truth graph, and the solver's solution.
    """
    time.sleep(0.1)  # to wait for ilpy prints
    print(
        f"\nCandidate graph\t\t{len(graph.nodes):3} nodes\t{len(graph.edges):3} edges"
    )
    if gt_graph:
        print(
            f"Ground truth graph\t{len(gt_graph.nodes):3}"
            f" nodes\t{len(gt_graph.edges):3} edges"
        )

    node_selected = solver.get_variables(motile.variables.NodeSelected)
    edge_selected = solver.get_variables(motile.variables.EdgeSelected)
    # An indicator value > 0.5 means the node/edge is part of the solution.
    nodes = sum(
        1 for node in graph.nodes if solver.solution[node_selected[node]] > 0.5
    )
    edges = sum(
        1 for u, v in graph.edges if solver.solution[edge_selected[(u, v)]] > 0.5
    )
    print(f"Solution graph\t\t{nodes:3} nodes\t{edges:3} edges")
+ +MIT License + +Copyright (c) 2023 Funke lab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. +""" + +import logging + +import networkx as nx + +logger = logging.getLogger(__name__) + + +class TrackGraph(nx.DiGraph): + """A :class:`networkx.DiGraph` of objects with positions in time and space, + and inter-frame edges between them. + + Provides a few convenience methods for time series graphs in addition to + all the methods inherited from :class:`networkx.DiGraph`. + + Args: + graph_data (optional): + + Optional graph data to pass to the :class:`networkx.DiGraph` + constructor as ``incoming_graph_data``. This can be used to + populate a track graph with entries from a generic + ``networkx`` graph. + + frame_attribute (``string``, optional): + + The name of the node attribute that corresponds to the frame (i.e., + the time dimension) of the object. Defaults to ``'t'``. 
+ """ + + def __init__(self, graph_data=None, frame_attribute="t"): + super().__init__(incoming_graph_data=graph_data) + + self.frame_attribute = frame_attribute + self._graph_changed = True + + self._update_metadata() + + def prev_edges(self, node): + """Get all edges that point forward into ``node``.""" + return self.in_edges(node) + + def next_edges(self, node): + """Get all edges that point forward out of ``node``.""" + return self.out_edges(node) + + def get_frames(self): + """Get a tuple ``(t_begin, t_end)`` of the first and last frame + (exclusive) this track graph has nodes for. + """ + self._update_metadata() + + return (self.t_begin, self.t_end) + + def nodes_by_frame(self, t): + """Get all nodes in frame ``t``.""" + self._update_metadata() + + if t not in self._nodes_by_frame: + return [] + return self._nodes_by_frame[t] + + def _update_metadata(self): + if not self._graph_changed: + return + + self._graph_changed = False + + if self.number_of_nodes() == 0: + self._nodes_by_frame = {} + self.t_begin = None + self.t_end = None + return + + self._nodes_by_frame = {} + for node, data in self.nodes(data=True): + t = data[self.frame_attribute] + if t not in self._nodes_by_frame: + self._nodes_by_frame[t] = [] + self._nodes_by_frame[t].append(node) + + frames = self._nodes_by_frame.keys() + self.t_begin = min(frames) + self.t_end = max(frames) + 1 + + # ensure edges point forwards in time + for u, v in self.edges: + t_u = self.nodes[u][self.frame_attribute] + t_v = self.nodes[v][self.frame_attribute] + assert t_u < t_v, ( + f"Edge ({u}, {v}) does not point forwards in time, but from " + f"frame {t_u} to {t_v}" + ) + + self._graph_changed = False + + # wrappers around node/edge add/remove methods: + + def add_node(self, n, **attr): + super().add_node(n, **attr) + self._graph_changed = True + + def add_nodes_from(self, nodes, **attr): + super().add_nodes_from(nodes, **attr) + self._graph_changed = True + + def remove_node(self, n): + super().remove_node(n) + 
self._graph_changed = True + + def remove_nodes_from(self, nodes): + super().remove_nodes_from(nodes) + self._graph_changed = True + + def add_edge(self, u, v, **attr): + super().add_edge(u, v, **attr) + self._graph_changed = True + + def add_edges_from(self, ebunch_to_add, **attr): + super().add_edges_from(ebunch_to_add, **attr) + self._graph_changed = True + + def add_weighted_edges_From(self, ebunch_to_add): + super().add_weighted_edges_From(ebunch_to_add) + self._graph_changed = True + + def remove_edge(self, u, v): + super().remove_edge(u, v) + self._graph_changed = True + + def update(self, edges, nodes): + super().update(edges, nodes) + self._graph_changed = True + + def clear(self): + super().clear() + self._graph_changed = True + + def clear_edges(self): + super().clear_edges() + self._graph_changed = True diff --git a/models/tra_post_model/trackastra/tracking/tracking.py b/models/tra_post_model/trackastra/tracking/tracking.py new file mode 100644 index 0000000000000000000000000000000000000000..4ea914b9b874a62bc4c77048e3825729e24e7f3c --- /dev/null +++ b/models/tra_post_model/trackastra/tracking/tracking.py @@ -0,0 +1,165 @@ +import logging +from itertools import chain + +import networkx as nx +import numpy as np +import scipy +from tqdm import tqdm + +from .track_graph import TrackGraph +from typing import Optional, Tuple + +# from trackastra.tracking import graph_to_napari_tracks, graph_to_ctc + +logger = logging.getLogger(__name__) + + +def copy_edge(edge: tuple, source: nx.DiGraph, target: nx.DiGraph): + if edge[0] not in target.nodes: + target.add_node(edge[0], **source.nodes[edge[0]]) + if edge[1] not in target.nodes: + target.add_node(edge[1], **source.nodes[edge[1]]) + target.add_edge(edge[0], edge[1], **source.edges[(edge[0], edge[1])]) + + +def track_greedy( + candidate_graph, + allow_divisions=True, + threshold=0.5, + edge_attr="weight", +): + """Greedy matching, global. 
+ + Iterates over global edges sorted by weight, and keeps edge if feasible and weight above threshold. + + Args: + allow_divisions (bool, optional): + Whether to model divisions. Defaults to True. + + Returns: + solution_graph: NetworkX graph of tracks + """ + logger.info("Running greedy tracker") + + solution_graph = nx.DiGraph() + + # TODO bring back + # if args.gt_as_dets: + # solution_graph.add_nodes_from(candidate_graph.nodes(data=True)) + + edges = candidate_graph.edges(data=True) + edges = sorted( + edges, + key=lambda edge: edge[2][edge_attr], + reverse=True, + ) + + for edge in tqdm(edges, desc="Greedily matched edges"): + node_in, node_out, features = edge + assert ( + features[edge_attr] <= 1.0 + ), "Edge weights are assumed to be normalized to [0,1]" + # assumes sorted edges + if features[edge_attr] < threshold: + break + # Check whether this edge is a feasible edge to add + # i.e. no fusing + if node_out in solution_graph.nodes and solution_graph.in_degree(node_out) > 0: + # target node already has an incoming edge + continue + if node_in in solution_graph and solution_graph.out_degree(node_in) >= ( + 2 if allow_divisions else 1 + ): + # parent node already has max number of outgoing edges + continue + # otherwise add to solution + copy_edge(edge, candidate_graph, solution_graph) + + # df, masks = graph_to_ctc(solution_graph, masks_original) + # tracks, tracks_graph, _ = graph_to_napari_tracks(solution_graph) + + return solution_graph + # TODO this should all be in a tracker class + # return df, masks, solution_graph, tracks_graph, tracks, candidate_graph + + +def build_graph( + nodes: dict, + weights: Optional[tuple] = None, + use_distance: bool = False, + max_distance: Optional[int] = None, + max_neighbors: Optional[int] = None, + delta_t=1, +): + logger.info(f"Build candidate graph with {delta_t=}") + G = nx.DiGraph() + + for node in nodes: + G.add_node( + node["id"], + time=node["time"], + label=node["label"], + coords=node["coords"], + # 
index=node["index"], + weight=1, + ) + + if use_distance: + weights = None + if weights: + weights = {w[0]: w[1] for w in weights} + + graph = TrackGraph(G, frame_attribute="time") + frame_pairs = zip( + chain(*[ + list(range(graph.t_begin, graph.t_end - d)) for d in range(1, delta_t + 1) + ]), + chain(*[ + list(range(graph.t_begin + d, graph.t_end)) for d in range(1, delta_t + 1) + ]), + ) + iterator = tqdm( + frame_pairs, + total=(graph.t_end - graph.t_begin) * delta_t, + leave=False, + ) + for t_begin, t_end in iterator: + n_edges_t = len(G.edges) + ni, nj = graph.nodes_by_frame(t_begin), graph.nodes_by_frame(t_end) + pi = [] + for _ni in ni: + pi.append(np.array(G.nodes[_ni]["coords"])) + pi = np.stack(pi) + pj = [] + for _nj in nj: + pj.append(np.array(G.nodes[_nj]["coords"])) + pj = np.stack(pj) + + dists = scipy.spatial.distance.cdist(pi, pj) + + for _i, _ni in enumerate(ni): + inds = np.argsort(dists[_i]) + neighbors = 0 + for _j, _nj in zip(inds, np.array(nj)[inds]): + if max_neighbors and neighbors >= max_neighbors: + break + dist = dists[_i, _j] + if max_distance is None or dist <= max_distance: + if weights is None: + G.add_edge(_ni, _nj, weight=1 - dist / max_distance) + neighbors += 1 + else: + if (_ni, _nj) in weights: + G.add_edge(_ni, _nj, weight=weights[(_ni, _nj)]) + neighbors += 1 + + e_added = len(G.edges) - n_edges_t + if e_added == 0: + logger.warning(f"No candidate edges in frame {t_begin}") + iterator.set_description( + f"{e_added} edges in frame {t_begin} Total edges: {len(G.edges)}" + ) + + logger.info(f"Added {len(G.nodes)} vertices, {len(G.edges)} edges") + + return G diff --git a/models/tra_post_model/trackastra/tracking/utils.py b/models/tra_post_model/trackastra/tracking/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..0871cde77bea9e89280080765e6c11712fb51b78 --- /dev/null +++ b/models/tra_post_model/trackastra/tracking/utils.py @@ -0,0 +1,397 @@ +import logging +from collections import deque +from 
class CtcTracklet:
    """A CTC-style tracklet: an ordered chain of graph nodes together with
    a parent reference and the frame at which the chain starts.
    """

    def __init__(self, parent: int, nodes: List[int], start_frame: int) -> None:
        # Parent tracklet reference (-1 denotes "no parent").
        self.parent = parent
        # Node ids forming this tracklet, in temporal order.
        self.nodes = nodes
        # Frame of the tracklet's first node.
        self.start_frame = start_frame

    def __lt__(self, other):
        # Order by start frame first; break ties with the parent id.
        if self.start_frame != other.start_frame:
            return self.start_frame < other.start_frame
        return self.parent < other.parent

    def __str__(self) -> str:
        return f"Tracklet(parent={self.parent}, nodes={self.nodes})"

    def __repr__(self) -> str:
        return str(self)
+ + - first node after + - a division (out_degree of parent = 2) + - an appearance (in_degree=0) + - a gap closing event (delta_t to parent node > 1) + - inner nodes have in_degree=1 and out_degree=1, delta_t=1 + - last node: + - before a division (out_degree = 2) + - before a disappearance (out_degree = 0) + - before a gap closing event (delta_t to next node > 1) + """ + tracklets = [] + # get all nodes with out_degree == 2 (i.e. parent of a tracklet) + + # Queue of tuples(parent id, start node id) + starts = deque() + starts.extend([ + (p, d) for p in G.nodes for d in G.successors(p) if G.out_degree[p] == 2 + ]) + # set parent = -1 since there is no parent + starts.extend([(-1, n) for n in G.nodes if G.in_degree[n] == 0]) + while starts: + _p, _s = starts.popleft() + nodes = [_s] + # build a tracklet + c = _s + while True: + if G.out_degree[c] > 2: + raise ValueError("More than two daughters!") + if G.out_degree[c] == 2: + break + if G.out_degree[c] == 0: + break + t_c = G.nodes[c][frame_attribute] + suc = next(iter(G.successors(c))) + t_suc = G.nodes[suc][frame_attribute] + if t_suc - t_c > 1: + logger.debug( + f"Gap closing edge from `{c} (t={t_c})` to `{suc} (t={t_suc})`" + ) + starts.append((c, suc)) + break + # Add node to tracklet + c = next(iter(G.successors(c))) + nodes.append(c) + + tracklets.append( + CtcTracklet( + parent=_p, nodes=nodes, start_frame=G.nodes[_s][frame_attribute] + ) + ) + + return tracklets + + +def linear_chains(G: nx.DiGraph): + """Find all linear chains in a tree/graph, i.e. paths that. + + i) either start/end at a node with out_degree>in_degree or and have no internal branches, or + ii) consists of a single node or a single splitting node + + Note that each chain includes its start/end node, i.e. they can be appear in multiple chains. + """ + # get all nodes with out_degree>in_degree (i.e. 
def graph_to_napari_tracks(
    graph: "nx.DiGraph",
    properties: Optional[List[str]] = None,
):
    """Convert a track graph to napari tracks.

    Args:
        graph: Solution graph whose nodes carry "time" and "coords"
            attributes (plus any attributes named in `properties`).
        properties: Names of additional node attributes to collect for
            every track point. Defaults to none.
            Fix: the default used to be a shared mutable ``[]``.

    Returns:
        tracks: np.ndarray of rows ``[track_id, t, *coords]``.
        tracks_graph: dict mapping a track id to its parent track id.
        tracks_props: dict of property name -> list of per-point values.
    """
    if properties is None:
        properties = []

    # each tracklet is a linear chain in the graph
    chains = tuple(linear_chains(graph))

    # Assign a 1-based track id to every chain and remember each chain's
    # final node, so that child chains can look up their parent track id.
    track_end_to_track_id = dict()
    labels = []
    for i, cs in enumerate(chains):
        label = i + 1
        labels.append(label)
        track_end_to_track_id[cs[-1]] = label

    tracks = []
    tracks_graph = dict()
    tracks_props = {p: [] for p in properties}

    for label, cs in tqdm(zip(labels, chains), total=len(chains)):
        start = cs[0]
        if start in track_end_to_track_id and len(cs) > 1:
            # The chain begins at another chain's end node -> child track;
            # skip the shared node so it is emitted only once.
            tracks_graph[label] = track_end_to_track_id[start]
            nodes = cs[1:]
        else:
            nodes = cs

        for c in nodes:
            node = graph.nodes[c]
            t = node["time"]
            coord = node["coords"]
            tracks.append([label, t, *list(coord)])

            for p in properties:
                tracks_props[p].append(node[p])

    tracks = np.array(tracks)
    return tracks, tracks_graph, tracks_props
def graph_to_edge_table(
    graph: "nx.DiGraph",
    frame_attribute: str = "time",
    edge_attribute: str = "weight",
    outpath: Optional[Path] = None,
) -> pd.DataFrame:
    """Write edges of a graph to a table.

    The table has columns `source_frame`, `source_label`, `target_frame`,
    `target_label`, and `weight`. The first line is a header. The source and
    target are the labels of the objects in the input masks in the
    designated frames (0-indexed).

    Args:
        graph: With node attributes `frame_attribute`, `edge_attribute` and 'label'.
        frame_attribute: Name of the frame attribute in `graph`.
        edge_attribute: Name of the score attribute in `graph`.
        outpath: If given, save the edges in CSV file format.

    Returns:
        pd.DataFrame: Edges DataFrame with columns
        ['source_frame', 'source_label', 'target_frame', 'target_label', 'weight'],
        sorted by source frame/label, then target frame/label.
    """
    # Idiom fix: build the rows with a comprehension instead of a manual
    # append loop.
    rows = [
        [
            int(graph.nodes[u][frame_attribute]),
            int(graph.nodes[u]["label"]),
            int(graph.nodes[v][frame_attribute]),
            int(graph.nodes[v]["label"]),
            float(graph.edges[(u, v)][edge_attribute]),
        ]
        for u, v in graph.edges
    ]

    df = pd.DataFrame(
        rows,
        columns=[
            "source_frame",
            "source_label",
            "target_frame",
            "target_label",
            "weight",
        ],
    )
    df = df.sort_values(
        by=["source_frame", "source_label", "target_frame", "target_label"],
        ascending=True,
    )

    if outpath is not None:
        outpath = Path(outpath)
        # Make sure the target folder exists before writing.
        outpath.parent.mkdir(parents=True, exist_ok=True)
        df.to_csv(outpath, index=False, header=True, sep=",")

    return df
masks_original: np.ndarray, + check: bool = True, + frame_attribute: str = "time", + outdir: Optional[Path] = None, +) -> Tuple[pd.DataFrame, np.ndarray]: + """Convert graph to ctc track Dataframe and relabeled masks. + + Args: + graph: with node attributes `frame_attribute` and "label" + masks_original: list of masks with unique labels + check: Check CTC format + frame_attribute: Name of the frame attribute in the graph nodes. + outdir: path to save results in CTC format. + + Returns: + pd.DataFrame: track dataframe with columns ['track_id', 't_start', 't_end', 'parent_id'] + np.ndarray: masks with unique color for each track + """ + # each tracklet is a linear chain in the graph + tracklets = ctc_tracklets(graph, frame_attribute=frame_attribute) + + regions = tuple( + dict((reg.label, reg.slice) for reg in regionprops(m)) + for t, m in enumerate(masks_original) + ) + + masks = np.stack([np.zeros_like(m) for m in masks_original]) + rows = [] + # To map parent references to tracklet ids. -1 means no parent, which is mapped to 0 in CTC format. 
+ node_to_tracklets = dict({-1: 0}) + + # Sort tracklets by parent id + for i, _tracklet in tqdm( + enumerate(sorted(tracklets)), + total=len(tracklets), + desc="Converting graph to CTC results", + ): + _parent = _tracklet.parent + _nodes = _tracklet.nodes + label = i + 1 + + _start, end = _nodes[0], _nodes[-1] + + t1 = _tracklet.start_frame + # t1 = graph.nodes[start][frame_attribute] + t2 = graph.nodes[end][frame_attribute] + + node_to_tracklets[end] = label + + # relabel masks + for _n in _nodes: + node = graph.nodes[_n] + t = node[frame_attribute] + lab = node["label"] + ss = regions[t][lab] + m = masks_original[t][ss] == lab + if masks[t][ss][m].max() > 0: + raise RuntimeError(f"Overlapping masks at t={t}, label={lab}") + if np.count_nonzero(m) == 0: + raise RuntimeError(f"Empty mask at t={t}, label={lab}") + masks[t][ss][m] = label + + rows.append([label, t1, t2, node_to_tracklets[_parent]]) + + df = pd.DataFrame(rows, columns=["label", "t1", "t2", "parent"], dtype=int) + + masks = np.stack(masks) + + if check: + _check_ctc_df(df, masks) + + if outdir is not None: + outdir = Path(outdir) + outdir.mkdir( + # mode=775, + parents=True, + exist_ok=True, + ) + df.to_csv(outdir / "res_track.txt", index=False, header=False, sep=" ") + for i, m in tqdm(enumerate(masks), total=len(masks), desc="Saving masks"): + tifffile.imwrite( + outdir / f"res_track{i:04d}.tif", + m, + compression="zstd", + ) + + return df, masks + + +def ctc_to_graph(df: pd.DataFrame, frame_attribute: str = "time"): + """From a ctc dataframe, create a digraph with frame_attribute and label as node attributes. + + Args: + df: pd.DataFrame with columns `label`, `t1`, `t2`, `parent` (man_track.txt) + frame_attribute: Name of the frame attribute in the graph nodes. 
+ + Returns: + graph: The track graph + """ + graph = nx.DiGraph() + + t1 = df.t1.min() + t2 = df.t2.max() + + for t in tqdm(range(t1, t2 + 1)): + obs = df[(df.t1 <= t) & (df.t2 >= t)] + for row in obs.itertuples(): + label, t1, t2, parent = row.label, row.t1, row.t2, row.parent + # add label as node if not already in graph + if not graph.has_node(label): + attrs = {"label": label, frame_attribute: t} + graph.add_node(label, **attrs) + + if parent != 0: + graph.add_edge(parent, label) + + return graph diff --git a/models/tra_post_model/trackastra/utils/__init__.py b/models/tra_post_model/trackastra/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f8b5c70d4216c50e2a81ff9023f87c7aea4ae2ef --- /dev/null +++ b/models/tra_post_model/trackastra/utils/__init__.py @@ -0,0 +1,14 @@ +# ruff: noqa: F401 + +from .utils import ( + blockwise_causal_norm, + blockwise_sum, + normalize, + normalize_01, + preallocate_memory, + random_label_cmap, + render_label, + seed, + str2bool, + str2path, +) diff --git a/models/tra_post_model/trackastra/utils/utils.py b/models/tra_post_model/trackastra/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e5dd8a72fac97824ab2b5fb58cd197f10d80f0bb --- /dev/null +++ b/models/tra_post_model/trackastra/utils/utils.py @@ -0,0 +1,494 @@ +import colorsys +import itertools +import logging +import random +import sys +from pathlib import Path +from timeit import default_timer + +import dask.array as da +import matplotlib +import numpy as np +import torch +from typing import Optional, Union + +logger = logging.getLogger(__name__) + + +def _single_color_integer_cmap(color=(0.3, 0.4, 0.5)): + from matplotlib.colors import Colormap + + assert len(color) in (3, 4) + + class BinaryMap(Colormap): + def __init__(self, color): + self.color = np.array(color) + if len(self.color) == 3: + self.color = np.concatenate([self.color, [1]]) + + def __call__(self, X, alpha=None, bytes=False): + res = 
np.zeros((*X.shape, 4), np.float32) + res[..., -1] = self.color[-1] + res[X > 0] = np.expand_dims(self.color, 0) + if bytes: + return np.clip(256 * res, 0, 255).astype(np.uint8) + else: + return res + + return BinaryMap(color) + + +def render_label( + lbl, + img=None, + cmap=None, + cmap_img="gray", + alpha=0.5, + alpha_boundary=None, + normalize_img=True, +): + """Renders a label image and optionally overlays it with another image. Used for generating simple output images to asses the label quality. + + Parameters + ---------- + lbl: np.ndarray of dtype np.uint16 + The 2D label image + img: np.ndarray + The array to overlay the label image with (optional) + cmap: string, tuple, or callable + The label colormap. If given as rgb(a) only a single color is used, if None uses a random colormap + cmap_img: string or callable + The colormap of img (optional) + alpha: float + The alpha value of the overlay. Set alpha=1 to get fully opaque labels + alpha_boundary: float + The alpha value of the boundary (if None, use the same as for labels, i.e. 
no boundaries are visible) + normalize_img: bool + If True, normalizes the img (if given) + + Returns: + ------- + img: np.ndarray + the (m,n,4) RGBA image of the rendered label + + Example: + ------- + from scipy.ndimage import label, zoom + img = zoom(np.random.uniform(0,1,(16,16)),(8,8),order=3) + lbl,_ = label(img>.8) + u1 = render_label(lbl, img = img, alpha = .7) + u2 = render_label(lbl, img = img, alpha = 0, alpha_boundary =.8) + plt.subplot(1,2,1);plt.imshow(u1) + plt.subplot(1,2,2);plt.imshow(u2) + + """ + from matplotlib import cm + from skimage.segmentation import find_boundaries + + alpha = np.clip(alpha, 0, 1) + + if alpha_boundary is None: + alpha_boundary = alpha + + if cmap is None: + cmap = random_label_cmap() + elif isinstance(cmap, tuple): + cmap = _single_color_integer_cmap(cmap) + else: + pass + + cmap = cm.get_cmap(cmap) if isinstance(cmap, str) else cmap + cmap_img = cm.get_cmap(cmap_img) if isinstance(cmap_img, str) else cmap_img + + # render image if given + if img is None: + im_img = np.zeros((*lbl.shape, 4), np.float32) + im_img[..., -1] = 1 + + else: + assert lbl.shape[:2] == img.shape[:2] + img = normalize(img) if normalize_img else img + if img.ndim == 2: + im_img = cmap_img(img) + elif img.ndim == 3: + im_img = img[..., :4] + if img.shape[-1] < 4: + im_img = np.concatenate( + [img, np.ones(img.shape[:2] + (4 - img.shape[-1],))], axis=-1 + ) + else: + raise ValueError("img should be 2 or 3 dimensional") + + # render label + im_lbl = cmap(lbl) + + mask_lbl = lbl > 0 + mask_bound = np.bitwise_and(mask_lbl, find_boundaries(lbl, mode="thick")) + + # blend + im = im_img.copy() + + im[mask_lbl] = alpha * im_lbl[mask_lbl] + (1 - alpha) * im_img[mask_lbl] + im[mask_bound] = ( + alpha_boundary * im_lbl[mask_bound] + (1 - alpha_boundary) * im_img[mask_bound] + ) + + return im + + +def random_label_cmap(n=2**16, h=(0, 1), lightness=(0.4, 1), s=(0.2, 0.8)): + h, lightness, s = ( + np.random.uniform(*h, n), + np.random.uniform(*lightness, n), + 
np.random.uniform(*s, n), + ) + cols = np.stack( + [colorsys.hls_to_rgb(_h, _l, _s) for _h, _l, _s in zip(h, lightness, s)], axis=0 + ) + cols[0] = 0 + return matplotlib.colors.ListedColormap(cols) + + +# @torch.jit.script +def _blockwise_sum_with_bounds(A: torch.Tensor, bounds: torch.Tensor, dim: int = 0): + A = A.transpose(dim, 0) + cum = torch.cumsum(A, dim=0) + cum = torch.cat((torch.zeros_like(cum[:1]), cum), dim=0) + B = torch.zeros_like(A, device=A.device) + for i, j in itertools.pairwise(bounds[:-1], bounds[1:]): + B[i:j] = cum[j] - cum[i] + B = B.transpose(0, dim) + return B + + +def _bounds_from_timepoints(timepoints: torch.Tensor): + assert timepoints.ndim == 1 + bounds = torch.cat(( + torch.tensor([0], device=timepoints.device), + # torch.nonzero faster than torch.where + torch.nonzero(timepoints[1:] - timepoints[:-1], as_tuple=False)[:, 0] + 1, + torch.tensor([len(timepoints)], device=timepoints.device), + )) + return bounds + + +# def blockwise_sum(A: torch.Tensor, timepoints: torch.Tensor, dim: int = 0): +# # get block boundaries +# assert A.shape[dim] == len(timepoints) + +# bounds = _bounds_from_timepoints(timepoints) + +# # normalize within blocks +# u = _blockwise_sum_with_bounds(A, bounds, dim=dim) +# return u + + +def blockwise_sum( + A: torch.Tensor, timepoints: torch.Tensor, dim: int = 0, reduce: str = "sum" +): + if not A.shape[dim] == len(timepoints): + raise ValueError( + f"Dimension {dim} of A ({A.shape[dim]}) must match length of timepoints" + f" ({len(timepoints)})" + ) + + A = A.transpose(dim, 0) + + if len(timepoints) == 0: + logger.warning("Empty timepoints in block_sum. Returning zero tensor.") + return A + # -1 is the filling value for padded/invalid timepoints + min_t = timepoints[timepoints >= 0] + if len(min_t) == 0: + logger.warning("All timepoints are -1 in block_sum. 
# TODO allow for batch dimension. Should be faster than looping
def blockwise_causal_norm(
    A: torch.Tensor,
    timepoints: torch.Tensor,
    mode: str = "quiet_softmax",
    mask_invalid: Optional[torch.BoolTensor] = None,
    eps: float = 1e-6,
):
    """Normalization over the causal dimension of A.

    For each block of constant timepoints, normalize the corresponding block of A
    such that the sum over the causal dimension is 1.

    Args:
        A (torch.Tensor): input tensor, square (n, n); rows/columns are
            ordered by `timepoints`.
        timepoints (torch.Tensor): timepoints for each element in the causal dimension
        mode: normalization mode.
            `linear`: Simple linear normalization.
            `softmax`: Apply exp to A before normalization.
            `quiet_softmax`: Apply exp to A before normalization, and add 1 to the denominator of each row/column.
        mask_invalid: Values that should not influence the normalization.
        eps (float, optional): epsilon for numerical stability.

    Returns:
        torch.Tensor: (n, n) tensor, values clamped to [0, 1]; entries with
        t_col > t_row are normalized along dim=0, all others along dim=1.
    """
    assert A.ndim == 2 and A.shape[0] == A.shape[1]
    A = A.clone()

    if mode in ("softmax", "quiet_softmax"):
        # Subtract max for numerical stability
        # https://stats.stackexchange.com/questions/338285/how-does-the-subtraction-of-the-logit-maximum-improve-learning
        # TODO test without this subtraction

        if mask_invalid is not None:
            assert mask_invalid.shape == A.shape
            # -inf logits vanish after exp, so masked entries contribute nothing.
            A[mask_invalid] = -torch.inf
            # TODO set to min, then to 0 after exp

        # Blockwise max (no grad: only used as a stabilizing shift)
        with torch.no_grad():
            ma0 = blockwise_sum(A, timepoints, dim=0, reduce="amax")
            ma1 = blockwise_sum(A, timepoints, dim=1, reduce="amax")

        u0 = torch.exp(A - ma0)
        u1 = torch.exp(A - ma1)
    elif mode == "linear":
        A = torch.sigmoid(A)
        if mask_invalid is not None:
            assert mask_invalid.shape == A.shape
            A[mask_invalid] = 0

        # Same values for both directions; no max shift needed.
        u0, u1 = A, A
        ma0 = ma1 = 0
    else:
        raise NotImplementedError(f"Mode {mode} not implemented")

    # get block boundaries and normalize within blocks
    # bounds = _bounds_from_timepoints(timepoints)
    # u0_sum = _blockwise_sum_with_bounds(u0, bounds, dim=0) + eps
    # u1_sum = _blockwise_sum_with_bounds(u1, bounds, dim=1) + eps

    u0_sum = blockwise_sum(u0, timepoints, dim=0) + eps
    u1_sum = blockwise_sum(u1, timepoints, dim=1) + eps

    if mode == "quiet_softmax":
        # Add 1 to the denominator of the softmax. With this, the softmax outputs can be all 0, if the logits are all negative.
        # If the logits are positive, the softmax outputs will sum to 1.
        # Trick: With maximum subtraction, this is equivalent to adding 1 to the denominator
        u0_sum += torch.exp(-ma0)
        u1_sum += torch.exp(-ma1)

    # mask0: column timepoint strictly later than row timepoint.
    mask0 = timepoints.unsqueeze(0) > timepoints.unsqueeze(1)
    # mask1 = timepoints.unsqueeze(0) < timepoints.unsqueeze(1)
    # Entries with t1 == t2 are always masked out in final loss
    mask1 = ~mask0

    # blockwise diagonal will be normalized along dim=0
    res = mask0 * u0 / u0_sum + mask1 * u1 / u1_sum
    res = torch.clamp(res, 0, 1)
    return res


def normalize_tensor(x: torch.Tensor, dim: Optional[int] = None, eps: float = 1e-8):
    # Min-max normalize x to [0, 1] over `dim` (all dims if None).
    if dim is None:
        dim = tuple(range(x.ndim))

    mi, ma = torch.amin(x, dim=dim, keepdim=True), torch.amax(x, dim=dim, keepdim=True)
    return (x - mi) / (ma - mi + eps)


def normalize(x: Union[np.ndarray, da.Array], subsample: Optional[int] = 4):
    """Percentile normalize the image.

    If subsample is not None, calculate the percentile values over a subsampled image (last two axis)
    which is way faster for large images.
    """
    x = x.astype(np.float32)
    # Only subsample when the image is large enough for the estimate to be reliable.
    if subsample is not None and all(s > 64 * subsample for s in x.shape[-2:]):
        y = x[..., ::subsample, ::subsample]
    else:
        y = x

    # Robust (1, 99.8) percentile stretch; in-place on the float32 copy.
    mi, ma = np.percentile(y, (1, 99.8)).astype(np.float32)
    x -= mi
    x /= ma - mi + 1e-8
    return x
def batched(x, batch_size, device):
    """Replicate ``x`` along a new leading batch dimension and move it to ``device``.

    The replication uses ``expand`` (a view, no copy) before the device transfer.
    """
    target_shape = (batch_size,) + (-1,) * x.ndim
    return x.unsqueeze(0).expand(*target_shape).to(device)
def seed(s=None):
    """Seed the stdlib, numpy and (if already imported) torch RNGs.

    Defaults to the unix timestamp of the function call when no seed is given.

    Args:
        s (``int``): Manual seed.

    Returns:
        The seed that was actually used.
    """
    if s is None:
        s = int(default_timer())

    # Seed each generator in turn, logging which one was seeded.
    for rng_name, seed_fn in (("random", random.seed), ("numpy", np.random.seed)):
        seed_fn(s)
        logger.debug(f"Seed `{rng_name}` rng with {s}.")

    # Only touch torch if the caller has already imported it.
    if "torch" in sys.modules:
        torch.manual_seed(s)
        logger.debug(f"Seed `torch` rng with {s}.")

    return s


def str2bool(x: str) -> bool:
    """Cast string to boolean.

    Useful for parsing command line arguments.
    """
    if not isinstance(x, str):
        raise TypeError("String expected.")
    lowered = x.lower()
    if lowered in ("true", "t", "1"):
        return True
    if lowered in ("false", "f", "0"):
        return False
    raise ValueError(f"'{x}' does not seem to be boolean.")


def str2path(x: str) -> Path:
    """Cast string to resolved absolute path.

    Useful for parsing command line arguments.
    """
    if not isinstance(x, str):
        raise TypeError("String expected.")
    return Path(x).expanduser().resolve()


if __name__ == "__main__":
    # Quick smoke test for the blockwise helpers.
    A = torch.rand(50, 50)
    idx = torch.tensor([0, 10, 20, A.shape[0]])

    A = torch.eye(50)

    B = _blockwise_sum_with_bounds(A, idx)

    tps = torch.repeat_interleave(torch.arange(5), 10)

    C = blockwise_causal_norm(A, tps)
class SegmentationModule(pl.LightningModule):
    """Lightning wrapper bundling LOCA, a counting adapter and Stable Diffusion
    for attention-based microscopy segmentation inference."""

    def __init__(self, use_box=True):
        # use_box: whether exemplar boxes are supplied to forward().
        super().__init__()
        self.use_box = use_box
        self.config = RunConfig() # config for stable diffusion
        self.initialize_model()


    def initialize_model(self):
        """Instantiate LOCA, the counting adapter and Stable Diffusion, wire the
        attention controller, and install the placeholder token embedding."""

        # load loca model
        loca_args = loca_get_argparser().parse_args()
        self.loca_model = build_loca_model(loca_args)
        self.loca_model.eval()

        self.counting_adapter = Counting(scale_factor=SCALE)

        ### load stable diffusion and its controller
        self.stable = load_stable_diffusion_model(config=self.config)
        self.noise_scheduler = self.stable.scheduler
        self.controller = AttentionStore(max_size=64)
        attn_utils.register_attention_control(self.stable, self.controller)
        attn_utils.register_hier_output(self.stable)

        ##### initialize token_emb #####
        # NOTE(review): placeholder_token is the empty string here — add_tokens("")
        # returns 0, which would always raise below. Possibly a token like "<...>"
        # was stripped from the source; confirm against the original repo.
        placeholder_token = ""
        self.task_token = "repetitive objects"
        # Add the placeholder token in tokenizer
        num_added_tokens = self.stable.tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        try:
            # print("loading pretrained task embedding from {}".format("pretrained/task_embed.pth"))
            # task_embed_from_pretrain = torch.load("pretrained/task_embed.pth")
            # NOTE(review): hf_hub_download returns a local file *path* (str), yet it
            # is assigned directly into the embedding row below without torch.load —
            # that assignment should fail and fall through to the except branch; verify.
            task_embed_from_pretrain = hf_hub_download(
                repo_id="phoebe777777/111",
                filename="task_embed.pth",
                token=None,
                force_download=False
            )
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)
            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))

            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            token_embeds[placeholder_token_id] = task_embed_from_pretrain
        # NOTE(review): bare except silently swallows download/IO errors and falls
        # back to the initializer-token path; consider narrowing the exception.
        except:
            # Fallback: initialize the placeholder embedding from an existing token.
            initializer_token = "segment"
            token_ids = self.stable.tokenizer.encode(initializer_token, add_special_tokens=False)
            # Check if initializer_token is a single token or a sequence of tokens
            if len(token_ids) > 1:
                # raise ValueError("The initializer token must be a single token.")
                token_ids = token_ids[:1]

            initializer_token_id = token_ids[0]
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)

            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))

            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

        # others
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id


    def move_to_device(self, device):
        # Move all sub-models (and self) to the given device.
        self.stable.to(device)
        self.counting_adapter.to(device)
        self.loca_model.to(device)

        self.to(device)


    def forward(self, data_path, box=None):
        """Run one segmentation inference on the image at ``data_path``.

        Args:
            data_path: path to the input image file.
            box: optional list of exemplar boxes in xyxy pixel coordinates of the
                original image; required iff self.use_box is True.

        Returns:
            np.ndarray segmentation prediction resized back to the original
            (width, height) of the input image.
        """
        filename = data_path.split("/")[-1]
        img = Image.open(data_path).convert("RGB")
        width, height = img.size
        # Two views of the input: one shifted to [-0.5, 0.5] for the SD VAE,
        # one ImageNet-normalized for LOCA.
        input_image = T.Compose([T.ToTensor(), T.Resize((512, 512))])(img)
        input_image_stable = input_image - 0.5
        input_image = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_image)
        if box is not None:
            # Rescale boxes from original pixel coords to the 512x512 working size.
            boxes = torch.tensor(box) / torch.tensor([width, height, width, height]) * 512 # xyxy, normalized
            assert self.use_box == True
        else:
            # Whole-image pseudo-box when no exemplar is given.
            boxes = torch.tensor([[0,0,512,512]])
            assert self.use_box == False
        img_raw = io.imread(data_path)
        if len(img_raw.shape) == 3 and img_raw.shape[2] > 3:
            img_raw = img_raw[:,:,:3]
        img_raw = cv2.resize(img_raw, (512, 512))

        # move to device
        input_image = input_image.unsqueeze(0).to(self.device)
        img_raw = torch.from_numpy(img_raw).unsqueeze(0).float().to(self.device)
        boxes = boxes.unsqueeze(0).to(self.device)
        input_image_stable = input_image_stable.unsqueeze(0).to(self.device)

        # Encode to SD latent space (0.18215 is the SD VAE scaling factor).
        latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
        latents = latents * 0.18215
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        bsz = latents.shape[0]
        # Fixed small diffusion timestep: a single lightly-noised denoising pass.
        timesteps = torch.tensor([20], device=latents.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        input_ids_ = self.stable.tokenizer(
            self.placeholder_token + " " + self.task_token,
            padding="max_length",
            truncation=True,
            max_length=self.stable.tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = input_ids_["input_ids"].to(self.device)
        attention_mask = input_ids_["attention_mask"].to(self.device)
        encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]
        encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1)

        # Position of the placeholder token inside the prompt.
        task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)

        if self.use_box:
            # Inject the box-conditioned counting embedding into the text stream.
            loca_out = self.loca_model.forward_before_reg(input_image, boxes)
            loca_feature_bf_regression = loca_out["feature_bf_regression"]
            adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes) # shape [1, 768]
            # adapted_emb = self.counting_adapter.adapter(data['crops_dino'], self.dino) # shape [1, 768]
            if task_loc_idx.shape[0] == 0:
                encoder_hidden_states[0,2,:] = adapted_emb.squeeze() # placed right after the task prompt token
            else:
                encoder_hidden_states[:,task_loc_idx[0, 1]+1,:] = adapted_emb.squeeze() # placed right after the task prompt token

        # Predict the noise residual
        noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
        time3 = time.time()
        noise_pred = noise_pred.sample

        attention_store = self.controller.attention_store

        # Per-resolution cross-attention maps for the task token and up to three
        # exemplar token positions.
        attention_maps = []
        exemplar_attention_maps1 = []
        exemplar_attention_maps2 = []
        exemplar_attention_maps3 = []

        cross_self_task_attn_maps = []
        cross_self_exe_attn_maps1 = []
        cross_self_exe_attn_maps2 = []
        cross_self_exe_attn_maps3 = []

        # only use 64x64 self-attention
        self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096]
            prompts=[self.config.prompt for i in range(bsz)], # TODO(review): should this use the actual prompt built above?
            attention_store=self.controller,
            res=64,
            from_where=("up", "down"),
            is_cross=False,
            select=0
        )

        # cross attention
        for res in [32, 16]:
            attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77]
                prompts=[self.config.prompt for i in range(bsz)], # TODO(review): should this use the actual prompt built above?
                attention_store=self.controller,
                res=res,
                from_where=("up", "down"),
                is_cross=True,
                select=0
            )

            task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res]
            attention_maps.append(task_attn_)
            exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0) # take the exemplar's attention
            exemplar_attention_maps1.append(exemplar_attns1)
            exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0) # take the exemplar's attention
            exemplar_attention_maps2.append(exemplar_attns2)
            exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0) # take the exemplar's attention
            exemplar_attention_maps3.append(exemplar_attns3)

        # Upsample all task cross-attention maps to 64x64, average, and refine
        # them with the 64x64 self-attention; min-max normalize each.
        scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
        attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))])
        task_attn_64 = torch.mean(attns, dim=0, keepdim=True)
        cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
        task_attn_64 = (task_attn_64 - task_attn_64.min()) / (task_attn_64.max() - task_attn_64.min() + 1e-6)
        cross_self_task_attn = (cross_self_task_attn - cross_self_task_attn.min()) / (cross_self_task_attn.max() - cross_self_task_attn.min() + 1e-6)

        scale_factors = [(64 // exemplar_attention_maps1[i].shape[-1]) for i in range(len(exemplar_attention_maps1))]
        attns = torch.cat([F.interpolate(exemplar_attention_maps1[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps1))])
        exemplar_attn_64_1 = torch.mean(attns, dim=0, keepdim=True)

        if self.use_box:
            # With a box exemplar, only the first exemplar slot is meaningful.
            exemplar_attn_64 = exemplar_attn_64_1
            cross_self_exe_attn = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64)
            exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6)
            cross_self_exe_attn = (cross_self_exe_attn - cross_self_exe_attn.min()) / (cross_self_exe_attn.max() - cross_self_exe_attn.min() + 1e-6)
        else:
            # Without a box, average the three exemplar-token attention maps.
            scale_factors = [(64 // exemplar_attention_maps2[i].shape[-1]) for i in range(len(exemplar_attention_maps2))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps2[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps2))])
            exemplar_attn_64_2 = torch.mean(attns, dim=0, keepdim=True)

            scale_factors = [(64 // exemplar_attention_maps3[i].shape[-1]) for i in range(len(exemplar_attention_maps3))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps3[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps3))])
            exemplar_attn_64_3 = torch.mean(attns, dim=0, keepdim=True)

            cross_self_exe_attn1 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_1)
            cross_self_exe_attn2 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_2)
            cross_self_exe_attn3 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_3)
            # # average
            exemplar_attn_64_1 = (exemplar_attn_64_1 - exemplar_attn_64_1.min()) / (exemplar_attn_64_1.max() - exemplar_attn_64_1.min() + 1e-6)
            exemplar_attn_64_2 = (exemplar_attn_64_2 - exemplar_attn_64_2.min()) / (exemplar_attn_64_2.max() - exemplar_attn_64_2.min() + 1e-6)
            exemplar_attn_64_3 = (exemplar_attn_64_3 - exemplar_attn_64_3.min()) / (exemplar_attn_64_3.max() - exemplar_attn_64_3.min() + 1e-6)
            cross_self_exe_attn1 = (cross_self_exe_attn1 - cross_self_exe_attn1.min()) / (cross_self_exe_attn1.max() - cross_self_exe_attn1.min() + 1e-6)
            cross_self_exe_attn2 = (cross_self_exe_attn2 - cross_self_exe_attn2.min()) / (cross_self_exe_attn2.max() - cross_self_exe_attn2.min() + 1e-6)
            cross_self_exe_attn3 = (cross_self_exe_attn3 - cross_self_exe_attn3.min()) / (cross_self_exe_attn3.max() - cross_self_exe_attn3.min() + 1e-6)

            exemplar_attn_64 = (exemplar_attn_64_1 + exemplar_attn_64_2 + exemplar_attn_64_3) / 3
            cross_self_exe_attn = (cross_self_exe_attn1 + cross_self_exe_attn2 + cross_self_exe_attn3) / 3

        # Stack the four attention channels for the regressor head; the task
        # channels are halved (or replaced by exemplar channels when no box).
        if self.use_box:
            attn_stack = [task_attn_64 / 2, cross_self_task_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        else:
            attn_stack = [exemplar_attn_64 / 2, cross_self_exe_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        attn_stack = torch.cat(attn_stack, dim=1)

        attn_after_new_regressor = self.counting_adapter.regressor(img_raw, attn_stack, feature_list) # use our own regressor directly

        # Resize prediction back to the original image size (nearest keeps labels intact).
        input_image = cv2.resize(input_image[0].permute(1,2,0).cpu().numpy(), (width, height))
        pred = cv2.resize(attn_after_new_regressor.squeeze().cpu().numpy(), (width, height), interpolation=cv2.INTER_NEAREST)
        return pred
def inference(data_path, box=None, save_path="./example_imgs", visualize=False):
    """Run segmentation on one image and optionally save a visualization.

    Args:
        data_path: path to the input image.
        box: optional exemplar boxes (list of xyxy pixel coords); enables box mode.
        save_path: directory where the visualization PNG is written.
        visualize: if True, plot input + segmentation overlay and save the figure.

    Returns:
        The predicted instance mask (np.ndarray).
    """
    if box is not None:
        use_box = True
    else:
        use_box = False
    model = SegmentationModule(use_box=use_box)
    # strict=True: the checkpoint must exactly match the module's parameters.
    load_msg = model.load_state_dict(torch.load("pretrained/microscopy_matching_seg.pth"), strict=True)
    model.eval()
    with torch.no_grad():
        mask = model(data_path, box)


    # visualize
    if visualize:
        img = io.imread(data_path)
        if len(img.shape) == 3 and img.shape[2] > 3:
            img = img[:,:,:3]
        if len(img.shape) == 2:
            img = np.stack([img]*3, axis=-1)
        img_show = img.squeeze()
        mask_show = mask.squeeze()
        os.makedirs(save_path, exist_ok=True)
        filename = data_path.split("/")[-1]
        fig, ax = plt.subplots(1,2, figsize=(12,6))
        ax[0].imshow(img_show)
        if use_box:
            boxes = np.array(box)
            # NOTE(review): the loop variable shadows the `box` parameter; harmless
            # here since `box` is not used afterwards, but worth renaming.
            for box in boxes:
                rect = patches.Rectangle((box[0], box[1]), box[2]-box[0], box[3]-box[1], linewidth=2, edgecolor='r', facecolor='none')
                ax[0].add_patch(rect)
            ax[0].set_title("Input Image with Box")
        else:
            ax[0].set_title("Input Image")
        ax[0].axis("off")
        ax[1].imshow(img_show)
        for inst_id in np.unique(mask_show):
            if inst_id == 0: # 0 is usually the background label
                continue
            # Build a binary mask for this instance and outline its contours.
            binary_mask = (mask_show == inst_id).astype(np.uint8)
            contours = measure.find_contours(binary_mask, 0.5)
            for contour in contours:
                ax[1].plot(contour[:, 1], contour[:, 0], linewidth=1.5, linestyle="--", color='yellow')
        ax[1].imshow(overlay_instances(img_show, mask_show, alpha=0.3))
        ax[1].set_title("Segmentation Result")
        ax[1].axis("off")
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, filename.split(".")[0]+"_seg.png"), dpi=300)
        plt.close()

    return mask


def main():
    # Example run on a bundled demo image (box mode disabled).
    inference(
        data_path="example_imgs/1977_Well_F-5_Field_1.png",
        # box=[[724, 864, 900, 966]],
        save_path="./example_imgs",
        visualize=True
    )


from matplotlib import cm

def overlay_instances(img, mask, alpha=0.5, cmap_name="tab20"):
    """Blend a colored instance mask over an image.

    Args:
        img: original image (H, W, 3), values in [0, 255] or [0, 1].
        mask: instance segmentation result (H, W); background = 0, instances = 1, 2, ...
        alpha: overlay opacity.
        cmap_name: matplotlib colormap name used to color the instances.

    Returns:
        Float RGB image in [0, 1] with instances tinted by their colormap color.
    """
    img = img.astype(np.float32)
    if len(img.shape) == 2:
        img = np.stack([img]*3, axis=-1)
    if img.max() > 1.5:
        img = img / 255.0


    overlay = img.copy()
    # NOTE(review): cm.get_cmap is deprecated (removed in matplotlib >= 3.9);
    # consider matplotlib.colormaps / plt.get_cmap when upgrading.
    cmap = cm.get_cmap(cmap_name, np.max(mask)+1)

    for inst_id in np.unique(mask):
        if inst_id == 0: # skip background
            continue
        color = np.array(cmap(inst_id)[:3]) # RGB
        overlay[mask == inst_id] = (1 - alpha) * overlay[mask == inst_id] + alpha * color

    return overlay

if __name__ == "__main__":
    main()
import Trackastra
from models.tra_post_model.trackastra.model import TrackingTransformer
from models.tra_post_model.trackastra.utils import (
    blockwise_causal_norm,
    blockwise_sum,
    normalize,
)
from models.tra_post_model.trackastra.data import build_windows_sd, get_features, load_tiff_timeseries
from models.tra_post_model.trackastra.tracking import TrackGraph, build_graph, track_greedy, graph_to_ctc
from _utils.track_args import parse_train_args as get_track_args
import torchvision.transforms as T
from pathlib import Path
import dask.array as da
from typing import Dict, List, Optional, Union, Literal
from scipy.sparse import SparseEfficiencyWarning, csr_array
import tracemalloc
import gc
# from memory_profiler import profile
from _utils.load_track_data import load_track_images

# Scale factor handed to the counting adapter (see Counting(scale_factor=SCALE)).
SCALE = 1


def get_instance_boxes(mask):
    """Compute one axis-aligned bounding box per instance in a label mask.

    Args:
        mask: Integer label map tensor of shape (H, W); 0 marks background,
            every other value marks one instance.

    Returns:
        Float32 tensor of shape (n_instances, 4) whose rows are
        ``[x_min, y_min, x_max, y_max]`` in pixel coordinates. Always has a
        trailing dimension of 4, i.e. shape (0, 4) when no instance exists.
    """
    # torch.unique / torch.where need an integer dtype; normalize to int64.
    if mask.dtype != torch.long:
        mask = mask.to(torch.long)

    boxes = []
    instance_ids = torch.unique(mask)
    instance_ids = instance_ids[instance_ids != 0]  # skip background

    for inst_id in instance_ids:
        inst_mask = mask == inst_id
        y_indices, x_indices = torch.where(inst_mask)

        # Defensive: an id returned by unique always has pixels, but keep the
        # guard in case a caller passes a pre-filtered id list some day.
        if len(x_indices) == 0 or len(y_indices) == 0:
            continue

        boxes.append([
            torch.min(x_indices).item(),
            torch.min(y_indices).item(),
            torch.max(x_indices).item(),
            torch.max(y_indices).item(),
        ])

    # Fix: torch.tensor([]) would yield shape (0,), not (0, 4), breaking the
    # (n, 4) contract that callers (e.g. forward_sd's pred_boxes handling)
    # rely on. Return an explicitly shaped empty tensor instead.
    if not boxes:
        return torch.zeros((0, 4), dtype=torch.float32)
    return torch.tensor(boxes, dtype=torch.float32)


class TrackingModule(pl.LightningModule):
    def __init__(self, use_box=False):
        super().__init__()
        self.use_box = use_box
        self.config = RunConfig()  # config for stable diffusion
        self.initialize_model()

    def initialize_model(self):

        # load loca model
        loca_args = loca_get_argparser().parse_args()
        self.loca_model = build_loca_model(loca_args)
        # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
        # weights = {k.replace("module","") : v for k, v in weights.items()}
        # 
self.loca_model.load_state_dict(weights, strict=False) + # del weights + + self.counting_adapter = Counting(scale_factor=SCALE) + # if os.path.isfile(self.args.adapter_weight): + # adapter_weight = torch.load(self.args.adapter_weight,map_location=torch.device('cpu')) + # self.counting_adapter.load_state_dict(adapter_weight, strict=False) + + ### load stable diffusion and its controller + self.stable = load_stable_diffusion_model(config=self.config) + self.noise_scheduler = self.stable.scheduler + self.controller = AttentionStore(max_size=64) + attn_utils.register_attention_control(self.stable, self.controller) + attn_utils.register_hier_output(self.stable) + + ##### initialize token_emb ##### + placeholder_token = "" + self.task_token = "repetitive objects" + # Add the placeholder token in tokenizer + num_added_tokens = self.stable.tokenizer.add_tokens(placeholder_token) + if num_added_tokens == 0: + raise ValueError( + f"The tokenizer already contains the token {placeholder_token}. Please pass a different" + " `placeholder_token` that is not already in the tokenizer." + ) + try: + # task_embed_from_pretrain = torch.load("pretrained/task_embed.pth") + task_embed_from_pretrain = hf_hub_download( + repo_id="phoebe777777/111", + filename="task_embed.pth", + token=None, + force_download=False + ) + placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token) + self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer)) + + token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data + token_embeds[placeholder_token_id] = task_embed_from_pretrain + except: + #@title Get token ids for our placeholder and initializer token. 
This code block will complain if initializer string is not a single token + # Convert the initializer_token, placeholder_token to ids + initializer_token = "track" + token_ids = self.stable.tokenizer.encode(initializer_token, add_special_tokens=False) + # Check if initializer_token is a single token or a sequence of tokens + if len(token_ids) > 1: + raise ValueError("The initializer token must be a single token.") + + initializer_token_id = token_ids[0] + placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token) + + self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer)) + + token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data + token_embeds[placeholder_token_id] = token_embeds[initializer_token_id] + + # others + self.placeholder_token = placeholder_token + self.placeholder_token_id = placeholder_token_id + + # tracking model + # fpath = Path("models/tra_post_model/trackastra/.models/general_2d/model.pt") + fpath = Path("_utils/config.yaml") + args_ = get_track_args() + + model = TrackingTransformer.from_cfg( + cfg_path=fpath, + args=args_, + ) + # model = TrackingTransformer.from_folder( + # Path(*fpath.parts[:-1]), + # args=args_, + # checkpoint_path=Path(*fpath.parts[-1:]), + # ) + + + self.track_model = model + self.track_args = args_ + + + def move_to_device(self, device): + self.stable.to(device) + # if self.loca_model is not None and self.counting_adapter is not None: + # self.loca_model.to(device) + # self.counting_adapter.to(device) + self.counting_adapter.to(device) + # self.dino.to(device) + self.loca_model.to(device) + self.track_model.to(device) + + self.to(device) + + def on_train_start(self) -> None: + device = self.device + dtype = self.dtype + self.stable.to(device,dtype) + + def on_validation_start(self) -> None: + device = self.device + dtype = self.dtype + self.stable.to(device,dtype) + + def forward(self, data): + + input_image_stable = data["image_stable"] + boxes = 
data["boxes"] + input_image = data["img_enc"] + mask = data["mask"] + latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach() + latents = latents * 0.18215 + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + timesteps = torch.tensor([20], device=latents.device).long() + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + input_ids_ = self.stable.tokenizer( + self.placeholder_token, + # "object", + padding="max_length", + truncation=True, + max_length=self.stable.tokenizer.model_max_length, + return_tensors="pt", + ) + input_ids = input_ids_["input_ids"].to(self.device) + attention_mask = input_ids_["attention_mask"].to(self.device) + encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0] + encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1) + + time1 = time.time() + input_image = input_image.to(self.device) + boxes = boxes.to(self.device) + + loca_out = self.loca_model.forward_before_reg(input_image, boxes) + loca_feature_bf_regression = loca_out["feature_bf_regression"] + # time2 = time.time() + + task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id) + adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes) # shape [1, 768] + + if task_loc_idx.shape[0] == 0: + encoder_hidden_states[0,2,:] = adapted_emb.squeeze() # 放在task prompt下一位 + else: + encoder_hidden_states[:,task_loc_idx[0, 1]+1,:] = adapted_emb.squeeze() # 放在task prompt下一位 + + # Predict the noise residual + noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states) + time3 = time.time() + noise_pred = noise_pred.sample + + attention_store = self.controller.attention_store + + # print(time2-time1, time3-time2) + + attention_maps = [] + exemplar_attention_maps = [] + + cross_self_task_attn_maps = [] + cross_self_exe_attn_maps = [] + + # only use 64x64 self-attention + self_attn_aggregate = 
attn_utils.aggregate_attention( # [res, res, 4096] + prompts=[self.config.prompt for i in range(bsz)], # 这里要改么 + attention_store=self.controller, + res=64, + from_where=("up", "down"), + is_cross=False, + select=0 + ) + self_attn_aggregate32 = attn_utils.aggregate_attention( # [res, res, 4096] + prompts=[self.config.prompt for i in range(bsz)], # 这里要改么 + attention_store=self.controller, + res=32, + from_where=("up", "down"), + is_cross=False, + select=0 + ) + self_attn_aggregate16 = attn_utils.aggregate_attention( # [res, res, 4096] + prompts=[self.config.prompt for i in range(bsz)], # 这里要改么 + attention_store=self.controller, + res=16, + from_where=("up", "down"), + is_cross=False, + select=0 + ) + + # cross attention + for res in [32, 16]: + attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77] + prompts=[self.config.prompt for i in range(bsz)], # 这里要改么 + attention_store=self.controller, + res=res, + from_where=("up", "down"), + is_cross=True, + select=0 + ) + + task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res] + attention_maps.append(task_attn_) + exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0) # 取exemplar的attn + exemplar_attention_maps.append(exemplar_attns) + + + scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))] + attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))]) + task_attn_64 = torch.mean(attns, dim=0, keepdim=True) + + + scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))] + attns = torch.cat([F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))]) + exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True) + + cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64) + cross_self_exe_attn = 
attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64) + cross_self_task_attn_maps.append(cross_self_task_attn) + cross_self_exe_attn_maps.append(cross_self_exe_attn) + + task_attn_64 = (task_attn_64 - task_attn_64.min()) / (task_attn_64.max() - task_attn_64.min() + 1e-6) + cross_self_task_attn = (cross_self_task_attn - cross_self_task_attn.min()) / (cross_self_task_attn.max() - cross_self_task_attn.min() + 1e-6) + exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6) + cross_self_exe_attn = (cross_self_exe_attn - cross_self_exe_attn.min()) / (cross_self_exe_attn.max() - cross_self_exe_attn.min() + 1e-6) + + attn_stack = [task_attn_64 / 2, cross_self_task_attn / 2, exemplar_attn_64, cross_self_exe_attn] + attn_stack = torch.cat(attn_stack, dim=1) + + + attn_after_new_regressor, loss = self.counting_adapter.regressor(input_image, attn_stack, feature_list, mask.cpu().numpy(), training=False) # 直接用自己的 + + return { + "attn_after_new_regressor":attn_after_new_regressor, + "task_attn_64":task_attn_64, + "cross_self_task_attn":cross_self_task_attn, + "exemplar_attn_64": exemplar_attn_64, + "cross_self_exe_attn": cross_self_exe_attn, + "noise_pred":noise_pred, + "noise":noise, + "self_attn_aggregate":self_attn_aggregate, + "self_attn_aggregate32":self_attn_aggregate32, + "self_attn_aggregate16":self_attn_aggregate16, + "loss": loss + } + + def forward_sd(self, input_image_stable, input_image, boxes, height, width, mask=None): + + input_image_stable = input_image_stable.to(self.device) + # density = data["density"] + if boxes is not None: + boxes = boxes.to(self.device) + input_image = input_image.to(self.device) + if mask is not None: + mask = mask.to(self.device) + else: + mask = torch.zeros((input_image.shape[0], 1, input_image.shape[2], input_image.shape[3])).to(self.device) + + latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach() + latents = latents * 0.18215 + 
# Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + timesteps = torch.tensor([20], device=latents.device).long() + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + input_ids_ = self.stable.tokenizer( + self.placeholder_token + " " + self.task_token, + # "object", + padding="max_length", + truncation=True, + max_length=self.stable.tokenizer.model_max_length, + return_tensors="pt", + ) + input_ids = input_ids_["input_ids"].to(self.device) + attention_mask = input_ids_["attention_mask"].to(self.device) + encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0] + encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1) + + + if boxes is not None and not self.training: + if self.adapt_emb is None: + loca_out_ = self.loca_model.forward_before_reg(input_image, boxes) + loca_feature_bf_regression_ = loca_out_["feature_bf_regression"] + adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression_, boxes) # shape [1, 768] + else: + adapted_emb = self.adapt_emb.to(self.device) + task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id) + if task_loc_idx.shape[0] == 0: + encoder_hidden_states[0,5,:] = adapted_emb.squeeze() # 放在task prompt下一位 + else: + encoder_hidden_states[:,task_loc_idx[0, 1]+4,:] = adapted_emb.squeeze() # 放在task prompt下一位 + + # Predict the noise residual + noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states) + noise_pred = noise_pred.sample + attention_store = self.controller.attention_store + + + attention_maps = [] + exemplar_attention_maps = [] + exemplar_attention_maps1 = [] + exemplar_attention_maps2 = [] + exemplar_attention_maps3 = [] + exemplar_attention_maps4 = [] + + cross_self_task_attn_maps = [] + cross_self_exe_attn_maps = [] + + # only use 64x64 self-attention + self_attn_aggregate = attn_utils.aggregate_attention( # [res, res, 4096] + prompts=[self.config.prompt for i in 
range(bsz)], # 这里要改么 + attention_store=self.controller, + res=64, + from_where=("up", "down"), + is_cross=False, + select=0 + ) + + # cross attention + for res in [32, 16]: + attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77] + prompts=[self.config.prompt for i in range(bsz)], # 这里要改么 + attention_store=self.controller, + res=res, + from_where=("up", "down"), + is_cross=True, + select=0 + ) + + task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res] + attention_maps.append(task_attn_) + # if self.boxes is not None and not self.training: + exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0) # 取exemplar的attn + exemplar_attention_maps1.append(exemplar_attns1) + exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0) # 取exemplar的attn + exemplar_attention_maps2.append(exemplar_attns2) + exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0) # 取exemplar的attn + exemplar_attention_maps3.append(exemplar_attns3) + exemplar_attns4 = attn_aggregate[:, :, 5].unsqueeze(0).unsqueeze(0) # 取exemplar的attn + exemplar_attention_maps4.append(exemplar_attns4) + + + + scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))] + attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))]) + task_attn_64 = torch.mean(attns, dim=0, keepdim=True) + cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64) + cross_self_task_attn_maps.append(cross_self_task_attn) + + # if not self.training: + scale_factors = [(64 // exemplar_attention_maps1[i].shape[-1]) for i in range(len(exemplar_attention_maps1))] + attns = torch.cat([F.interpolate(exemplar_attention_maps1[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps1))]) + exemplar_attn_64_1 = torch.mean(attns, dim=0, keepdim=True) + + scale_factors = [(64 // 
exemplar_attention_maps2[i].shape[-1]) for i in range(len(exemplar_attention_maps2))] + attns = torch.cat([F.interpolate(exemplar_attention_maps2[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps2))]) + exemplar_attn_64_2 = torch.mean(attns, dim=0, keepdim=True) + + scale_factors = [(64 // exemplar_attention_maps3[i].shape[-1]) for i in range(len(exemplar_attention_maps3))] + attns = torch.cat([F.interpolate(exemplar_attention_maps3[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps3))]) + exemplar_attn_64_3 = torch.mean(attns, dim=0, keepdim=True) + + if boxes is not None: + scale_factors = [(64 // exemplar_attention_maps4[i].shape[-1]) for i in range(len(exemplar_attention_maps4))] + attns = torch.cat([F.interpolate(exemplar_attention_maps4[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps4))]) + exemplar_attn_64_4 = torch.mean(attns, dim=0, keepdim=True) + + exes = [] + cross_exes = [] + cross_self_exe_attn1 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_1) + cross_self_exe_attn2 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_2) + cross_self_exe_attn3 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_3) + + # # average + exemplar_attn_64_1 = (exemplar_attn_64_1 - exemplar_attn_64_1.min()) / (exemplar_attn_64_1.max() - exemplar_attn_64_1.min() + 1e-6) + exemplar_attn_64_2 = (exemplar_attn_64_2 - exemplar_attn_64_2.min()) / (exemplar_attn_64_2.max() - exemplar_attn_64_2.min() + 1e-6) + exemplar_attn_64_3 = (exemplar_attn_64_3 - exemplar_attn_64_3.min()) / (exemplar_attn_64_3.max() - exemplar_attn_64_3.min() + 1e-6) + cross_self_exe_attn1 = (cross_self_exe_attn1 - cross_self_exe_attn1.min()) / (cross_self_exe_attn1.max() - cross_self_exe_attn1.min() + 1e-6) + cross_self_exe_attn2 = (cross_self_exe_attn2 - cross_self_exe_attn2.min()) / (cross_self_exe_attn2.max() - 
cross_self_exe_attn2.min() + 1e-6) + cross_self_exe_attn3 = (cross_self_exe_attn3 - cross_self_exe_attn3.min()) / (cross_self_exe_attn3.max() - cross_self_exe_attn3.min() + 1e-6) + exes = [exemplar_attn_64_1, exemplar_attn_64_2, exemplar_attn_64_3] + cross_exes = [cross_self_exe_attn1, cross_self_exe_attn2, cross_self_exe_attn3] + if boxes is not None: + cross_self_exe_attn4 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_4) + exemplar_attn_64_4 = (exemplar_attn_64_4 - exemplar_attn_64_4.min()) / (exemplar_attn_64_4.max() - exemplar_attn_64_4.min() + 1e-6) + cross_self_exe_attn4 = (cross_self_exe_attn4 - cross_self_exe_attn4.min()) / (cross_self_exe_attn4.max() - cross_self_exe_attn4.min() + 1e-6) + exes.append(exemplar_attn_64_4) + cross_exes.append(cross_self_exe_attn4) + exemplar_attn_64 = sum(exes) / len(exes) + cross_self_exe_attn = sum(cross_exes) / len(cross_exes) + + + + if self.use_box: + attn_stack = [task_attn_64 / 2, cross_self_task_attn / 2, exemplar_attn_64, cross_self_exe_attn] + else: + attn_stack = [exemplar_attn_64 / 2, cross_self_exe_attn / 2, exemplar_attn_64, cross_self_exe_attn] + attn_stack = torch.cat(attn_stack, dim=1) + + + attn_after_new_regressor, loss, _ = self.counting_adapter.regressor.forward_seg(input_image, attn_stack, feature_list, mask.cpu().numpy(), self.training) + + if not self.training: + pred_mask = attn_after_new_regressor.detach().cpu() + pred_boxes = get_instance_boxes(pred_mask.squeeze()) + + self.boxes = pred_boxes.unsqueeze(0) + + if pred_boxes.shape[0] == 0: + print("No instances detected in the predicted mask.") + self.adapt_emb = adapted_emb.detach().cpu() # reuse emb + else: + pred_boxes = pred_boxes.unsqueeze(0).to(self.device) + loca_out_ = self.loca_model.forward_before_reg(input_image, pred_boxes) + loca_feature_bf_regression_ = loca_out_["feature_bf_regression"] + adapted_emb_ = self.counting_adapter.adapter(loca_feature_bf_regression_, pred_boxes) # shape [1, 768] + self.adapt_emb = 
adapted_emb_.detach().cpu() + + # resize to original image size + mask_np = attn_after_new_regressor.squeeze().detach().cpu().numpy() + mask_resized = cv2.resize(mask_np, (width, height), interpolation=cv2.INTER_NEAREST) + + return mask_resized + + def forward_boxes(self, input_image_stable, boxes, input_image): + + latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach() + latents = latents * 0.18215 + # Sample noise that we'll add to the latents + noise = torch.randn_like(latents) + bsz = latents.shape[0] + timesteps = torch.tensor([20], device=latents.device).long() + noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps) + input_ids_ = self.stable.tokenizer( + self.placeholder_token, + # "object", + padding="max_length", + truncation=True, + max_length=self.stable.tokenizer.model_max_length, + return_tensors="pt", + ) + input_ids = input_ids_["input_ids"].to(self.device) + attention_mask = input_ids_["attention_mask"].to(self.device) + encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0] + encoder_hidden_states = encoder_hidden_states.repeat(bsz, 1, 1) + + time1 = time.time() + input_image = input_image.to(self.device) + boxes = boxes.to(self.device) + + loca_out = self.loca_model.forward_before_reg(input_image, boxes) + loca_feature_bf_regression = loca_out["feature_bf_regression"] + # time2 = time.time() + + task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id) + adapted_emb = self.counting_adapter.adapter.forward_boxes(loca_feature_bf_regression, boxes) # shape [n_instance, 768] + n_instance = adapted_emb.shape[0] + n_forward = int(np.ceil(n_instance / 74)) # in total 75 prompts including 1 task prompt and 74 object prompts? 
+ + task_cross_attention = [] + instances_cross_attention = [] + + for n in range(n_forward): + len_ = min(74, n_instance - n * 74) + encoder_hidden_states[:,(task_loc_idx[0, 1]+1):(task_loc_idx[0, 1]+1+len_),:] = adapted_emb[n*74:n*74+len_].squeeze() # 放在task prompt下一位 + # encoder_hidden_states: # [bsz, 77, 768], 其中第1位是task prompt的embedding, 第二位开始可以是object prompt的embedding, 最后一位应该保留原始embedding + + + # Predict the noise residual + noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states) + noise_pred = noise_pred.sample + + + + attention_maps = [] + exemplar_attention_maps = [] + + # cross attention + for res in [32, 16]: + attn_aggregate = attn_utils.aggregate_attention( # [res, res, 77] + prompts=[self.config.prompt for i in range(bsz)], # 这里要改么 + attention_store=self.controller, + res=res, + from_where=("up", "down"), + is_cross=True, + select=0 + ) + + task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0) # [1, 1, res, res] + attention_maps.append(task_attn_) + try: + exemplar_attns = attn_aggregate[:, :, (task_loc_idx[0, 1]+1):(task_loc_idx[0, 1]+1+len_)].unsqueeze(0) # 取exemplar的attn + except: + print(n_instance, len_) + exemplar_attns = torch.permute(exemplar_attns, (0, 3, 1, 2)) # [1, len_, res, res] + exemplar_attention_maps.append(exemplar_attns) + + + + scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))] + attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))]) + task_attn_64 = torch.mean(attns, dim=0, keepdim=True) + + try: + scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))] + attns = torch.cat([F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))]) + except: + print("exemplar_attention_maps shape mismatch, n_instance: {}, len_: 
{}".format(n_instance, len_)) + print(exemplar_attention_maps[0].shape) + print(exemplar_attention_maps[1].shape) + print(exemplar_attention_maps[2].shape) + exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True) + + + task_attn_64 = (task_attn_64 - task_attn_64.min()) / (task_attn_64.max() - task_attn_64.min() + 1e-6) + exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6) + + task_cross_attention.append(task_attn_64) + instances_cross_attention.append(exemplar_attn_64) + + task_cross_attention = torch.cat(task_cross_attention, dim=0) # [n_forward, 1, 64, 64] + task_cross_attention = torch.mean(task_cross_attention, dim=0, keepdim=True) # [1, 1, 64, 64] + instances_cross_attention = torch.cat(instances_cross_attention, dim=1) # [1, n_instance, 64, 64] + assert instances_cross_attention.shape[1] == n_instance, "instances_cross_attention shape mismatch" + attn_stack = [task_cross_attention / 2, instances_cross_attention] + attn_stack = torch.cat(attn_stack, dim=1) + + del exemplar_attention_maps, attention_maps, attns, task_attn_64, exemplar_attn_64, latents + del input_ids_, input_ids, attention_mask, encoder_hidden_states, timesteps, noisy_latents + del loca_out, loca_feature_bf_regression, adapted_emb + torch.cuda.empty_cache() + + return { + "task_attn_64":task_cross_attention, + "exemplar_attn_64": instances_cross_attention, + "noise_pred":noise_pred, + "noise":noise, + "attn_stack": attn_stack, + "feature_list": feature_list, + } + + + def common_step(self, batch): + mask = batch["mask_t"].to(torch.float32).to(self.device) + if mask.dim() == 3: + mask = mask.unsqueeze(0) + + image_stable = batch["image_stable"] + boxes = batch["boxes"] + input_image = batch["img_enc"] + input_image = input_image.to(self.device) + image_stable = image_stable.to(self.device) + keep_boxes = None + if image_stable.dim() == 4: + image_stable = image_stable.unsqueeze(0) + if input_image.dim() == 4: + input_image 
= input_image.unsqueeze(0) + + + # segmentation part + n_frames = mask.shape[1] + masks_pred = [] + + for i in range(n_frames): + mask_ = mask[:, i, :, :].unsqueeze(0) # [1, 1, H, W] + mask_ = F.interpolate(mask_.float(), size=(512, 512), mode='nearest') # [1, 1, 512, 512] + mask_ = mask_.to(torch.int64).squeeze(0).detach().to(self.device) # [1, 512, 512] + masks_pred.append(mask_) + del mask_ + + + + # if True: + attns_emb = [] + for i in range(n_frames): + image_stable_prev = image_stable[:, max(0, i-1), :, :, :] + image_stable_after = image_stable[:, min(n_frames-1, i+1), :, :, :] + input_image_curr = input_image[:, i, :, :, :] + + mask_ = masks_pred[i].detach() + unique_labels = torch.unique(mask_) # tensor([0, 1, 2, ...]) + boxes_all = [] + for label in unique_labels: + if label.item() == 0: + continue + binary_mask = (mask_[0] == label).to(torch.uint8) # [H, W] + + # 找非零点坐标 + y_coords, x_coords = torch.nonzero(binary_mask, as_tuple=True) + if len(x_coords) == 0 or len(y_coords) == 0: + continue + x_min = torch.min(x_coords) + y_min = torch.min(y_coords) + x_max = torch.max(x_coords) + y_max = torch.max(y_coords) + boxes_all.append([x_min.item(), y_min.item(), x_max.item(), y_max.item()]) + boxes_all_t = torch.tensor(boxes_all, dtype=torch.float32).unsqueeze(0) + + + output_prev = self.forward_boxes(image_stable_prev, boxes_all_t, input_image_curr) + attn_prev = output_prev["exemplar_attn_64"] + feature_list_prev = output_prev["feature_list"] + + output_after = self.forward_boxes(image_stable_after, boxes_all_t, input_image_curr) + # attn_stack = output["attn_stack"] + attn_after = output_after["exemplar_attn_64"] # [1, n_instance, 64, 64] + feature_list_after = output_after["feature_list"] # [1, n_channels, res, res] + + attn_prev = torch.permute(attn_prev, (1, 0, 2, 3)) # [n_instance, 1, 64, 64] + attn_after = torch.permute(attn_after, (1, 0, 2, 3)) + attn_emb = self.counting_adapter.regressor(attn_prev, feature_list_prev, attn_after, feature_list_after) + 
attns_emb.append(attn_emb.detach())

        # Stack per-frame embeddings along the instance axis for the tracker.
        attns_emb = torch.cat(attns_emb, dim=1) # [1, n_instance, 4]
        # tracking part

        feats = batch["features_t"]
        coords = batch["coords_t"]

        with torch.no_grad():

            A_pred = self.track_model(coords, feats, attn_feat=attns_emb).detach()

        # Aggressively release references and CUDA caches; windows can be large.
        del masks_pred, feats, coords, batch
        gc.collect()
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()

        return A_pred

    # @profile
    def _predict_batch(self, batch):
        """Run one sliding-window batch through the tracker.

        Mutates ``batch`` in place (rewrites "coords_t"/"features_t" with the
        batched, time-augmented versions consumed by ``common_step``) and
        returns the normalized association matrix as a NumPy array.
        """
        feats = batch["features_t"].to(self.device)
        coords = batch["coords_t"].to(self.device)
        timepoints = batch["timepoints_t"].to(self.device)
        # Hack that assumes that all parameters of a model are on the same device
        device = next(self.track_model.parameters()).device
        feats = feats.unsqueeze(0).to(device)
        timepoints = timepoints.unsqueeze(0).to(device)
        coords = coords.unsqueeze(0).to(device)

        # Concat timepoints to coordinates
        coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2)
        batch["coords_t"] = coords
        batch["features_t"] = feats
        with torch.no_grad():
            A = self.common_step(batch)
        torch.cuda.empty_cache()
        gc.collect()

        # Normalization uses the time-augmented coords built above.
        A = self.track_model.normalize_output(A, timepoints, coords)

        # # Spatially far entries should not influence the causal normalization
        # dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
        # invalid = dist > model.config["spatial_pos_cutoff"]
        # A[invalid] = -torch.inf

        A = A.squeeze(0).detach().cpu().numpy()

        del feats, coords, timepoints, batch

        return A

    # @profile
    def predict_windows(self,
        windows: List[dict],
        features: list,
        model,
        imgs_enc: Optional[np.ndarray] = None,
        imgs_stable: Optional[np.ndarray] = None,
        intra_window_weight: float = 0,
        delta_t: int = 1,
        edge_threshold: float = 0.05,
        spatial_dim: int = 3,
        progbar_class=tqdm,
    ) -> dict:
        """Predict association weights over all sliding windows.

        Returns a dict with "nodes" (per-object properties) and "weights"
        (window-averaged association edges ((i, j), w)).
        """

        # first get all objects/coords
        time_labels_to_id = dict()
        node_properties = list()
        # Upper bound on node count; used to size the sparse accumulators.
        max_id = np.sum([len(f.labels) for f in features])

        all_timepoints = 
np.concatenate([f.timepoints for f in features]) + all_labels = np.concatenate([f.labels for f in features]) + all_coords = np.concatenate([f.coords for f in features]) + all_coords = all_coords[:, -spatial_dim:] + + for i, (t, la, c) in enumerate(zip(all_timepoints, all_labels, all_coords)): + time_labels_to_id[(t, la)] = i + node_properties.append( + dict( + id=i, + coords=tuple(c), + time=t, + # index=ix, + label=la, + ) + ) + + # create assoc matrix between ids + sp_weights, sp_accum = ( + csr_array((max_id, max_id), dtype=np.float32), + csr_array((max_id, max_id), dtype=np.float32), + ) + + tracemalloc.start() + + for t in progbar_class( + range(len(windows)), + desc="Computing associations", + ): + # This assumes that the samples in the dataset are ordered by time and start at 0. + batch = windows[t] + timepoints = batch["timepoints"] + labels = batch["labels"] + + A = self._predict_batch(batch) + + dt = timepoints[None, :] - timepoints[:, None] + time_mask = np.logical_and(dt <= delta_t, dt > 0) + A[~time_mask] = 0 + ii, jj = np.where(A >= edge_threshold) + + if len(ii) == 0: + continue + + labels_ii = labels[ii] + labels_jj = labels[jj] + ts_ii = timepoints[ii] + ts_jj = timepoints[jj] + nodes_ii = np.array( + tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_ii, labels_ii)) + ) + nodes_jj = np.array( + tuple(time_labels_to_id[(t, lab)] for t, lab in zip(ts_jj, labels_jj)) + ) + + # weight middle parts higher + t_middle = t + (model.config["window"] - 1) / 2 + ddt = timepoints[:, None] - t_middle * np.ones_like(dt) + window_weight = np.exp(-intra_window_weight * ddt**2) # default is 1 + # window_weight = np.exp(4*A) # smooth max + sp_weights[nodes_ii, nodes_jj] += window_weight[ii, jj] * A[ii, jj] + sp_accum[nodes_ii, nodes_jj] += window_weight[ii, jj] + + + del batch, A, ii, jj, labels_ii, labels_jj, ts_ii, ts_jj, nodes_ii, nodes_jj, dt, time_mask + gc.collect() + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + sp_weights_coo = 
sp_weights.tocoo() + sp_accum_coo = sp_accum.tocoo() + assert np.allclose(sp_weights_coo.col, sp_accum_coo.col) and np.allclose( + sp_weights_coo.row, sp_accum_coo.row + ) + + # Normalize weights by the number of times they were written from different sliding window positions + weights = tuple( + ((i, j), v / a) + for i, j, v, a in zip( + sp_weights_coo.row, + sp_weights_coo.col, + sp_weights_coo.data, + sp_accum_coo.data, + ) + ) + + results = dict() + results["nodes"] = node_properties + results["weights"] = weights + + return results + + + def _predict( + self, + imgs: Union[np.ndarray, da.Array], + masks: Union[np.ndarray, da.Array], + imgs_enc: Optional[np.ndarray] = None, + imgs_stable: Optional[np.ndarray] = None, + boxes: Optional[np.ndarray] = None, + edge_threshold: float = 0.05, + n_workers: int = 0, + normalize_imgs: bool = True, + progbar_class=tqdm, + ): + print("Predicting weights for candidate graph") + if normalize_imgs: + if isinstance(imgs, da.Array): + imgs = imgs.map_blocks(normalize) + else: + imgs = normalize(imgs) + + self.eval() + + features = get_features( + detections=masks, + imgs=imgs, + ndim=self.track_model.config["coord_dim"], + n_workers=n_workers, + progbar_class=progbar_class, + ) + print("Building windows") + windows = build_windows_sd( + features, + imgs_enc=imgs_enc, + imgs_stable=imgs_stable, + boxes=boxes, + imgs=imgs, + masks=masks, + window_size=self.track_model.config["window"], + progbar_class=progbar_class, + ) + + print("Predicting windows") + with torch.no_grad(): + predictions = self.predict_windows( + windows=windows, + features=features, + imgs_enc=imgs_enc, + imgs_stable=imgs_stable, + model=self.track_model, + edge_threshold=edge_threshold, + spatial_dim=masks.ndim - 1, + progbar_class=progbar_class, + ) + + return predictions + + def _track_from_predictions( + self, + predictions, + mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy", + use_distance: bool = False, + max_distance: int = 256, + max_neighbors: 
int = 10, + delta_t: int = 1, + **kwargs, + ): + print("Running greedy tracker") + nodes = predictions["nodes"] + weights = predictions["weights"] + + candidate_graph = build_graph( + nodes=nodes, + weights=weights, + use_distance=use_distance, + max_distance=max_distance, + max_neighbors=max_neighbors, + delta_t=delta_t, + ) + if mode == "greedy": + return track_greedy(candidate_graph) + elif mode == "greedy_nodiv": + return track_greedy(candidate_graph, allow_divisions=False) + elif mode == "ilp": + from models.tra_post_model.trackastra.tracking.ilp import track_ilp + + return track_ilp(candidate_graph, ilp_config="gt", **kwargs) + else: + raise ValueError(f"Tracking mode {mode} does not exist.") + + def track( + self, + file_dir: str, + boxes: Optional[torch.Tensor] = None, + mode: Literal["greedy_nodiv", "greedy", "ilp"] = "greedy", + normalize_imgs: bool = True, + progbar_class=tqdm, + n_workers: int = 0, + dataname: Optional[str] = None, + **kwargs, + ) -> TrackGraph: + """Track objects across time frames. + + This method links segmented objects across time frames using the specified + tracking mode. No hyperparameters need to be chosen beyond the tracking mode. + + Args: + imgs: Input images of shape (T,(Z),Y,X) (numpy or dask array) + masks: Instance segmentation masks of shape (T,(Z),Y,X). + mode: Tracking mode: + - "greedy_nodiv": Fast greedy linking without division + - "greedy": Fast greedy linking with division + - "ilp": Integer Linear Programming based linking (more accurate but slower) + progbar_class: Progress bar class to use. + n_workers: Number of worker processes for feature extraction. + normalize_imgs: Whether to normalize the images. + **kwargs: Additional arguments passed to tracking algorithm. + + Returns: + TrackGraph containing the tracking results. 
+ """ + + self.eval() + imgs, imgs_raw, images_stable, tra_imgs, imgs_01, height, width = load_track_images(file_dir) + # tra_imgs = torch.from_numpy(imgs_).float().to(self.device) + imgs_stable = torch.from_numpy(images_stable).float().to(self.device) + imgs_enc = torch.from_numpy(imgs).float().to(self.device) + + + """get segmentation masks first""" + self.boxes = None + self.adapt_emb = None + masks = [] + for i, (input_image, input_image_stable) in tqdm(enumerate(zip(imgs_enc, imgs_stable))): + input_image = input_image.unsqueeze(0) + input_image_stable = input_image_stable.unsqueeze(0) + if i == 0: + if self.use_box and boxes is not None: + self.boxes = boxes.to(self.device) + else: + self.boxes = None + + with torch.no_grad(): + mask = self.forward_sd(input_image_stable, input_image, self.boxes, height=height, width=width) + masks.append(mask) + + masks = np.stack(masks, axis=0) # (T, H, W) + +# ------------------------- + if not masks.shape == tra_imgs.shape: + raise RuntimeError( + f"Img shape {tra_imgs.shape} and mask shape {masks.shape} do not match." 
+ ) + + if not tra_imgs.ndim == self.track_model.config["coord_dim"] + 1: + raise RuntimeError( + f"images should be a sequence of {self.track_model.config['coord_dim']}D images" + ) + + predictions = self._predict( + tra_imgs, + masks, + imgs_enc=imgs_enc, + imgs_stable=imgs_stable, + boxes=boxes, + normalize_imgs=normalize_imgs, + progbar_class=progbar_class, + n_workers=n_workers, + ) + track_graph = self._track_from_predictions(predictions, mode=mode, **kwargs) + + # ctc_tracks, masks_tracked = graph_to_ctc( + # track_graph, + # masks, + # outdir=f"tracked/{dataname}", + # ) + + return track_graph, masks + + + +def inference(data_path, box=None): + if box is not None: + use_box = True + else: + use_box = False + + model = TrackingModule(use_box=use_box) + load_msg = model.load_state_dict(torch.load("pretrained/microscopy_matching_tra.pth"), strict=True) + + model.move_to_device(torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')) + + + track_graph, masks = model.track(file_dir=data_path, dataname="inference_sequence") + + if not os.path.exists(f"tracked_ours_seg_pred3/"): + os.makedirs(f"tracked_ours_seg_pred3/") + ctc_tracks, masks_tracked = graph_to_ctc( + track_graph, + masks, + outdir=f"tracked_ours_seg_pred3/", + ) + +if __name__ == "__main__": + inference(data_path="example_imgs/2D+Time/Fluo-N2DL-HeLa/train/Fluo-N2DL-HeLa/02")