Spaces:
Sleeping
Sleeping
phoebehxf
commited on
Commit
·
aff3c6f
1
Parent(s):
01050f6
init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- README.md +5 -4
- _utils/attn_utils.py +592 -0
- _utils/attn_utils_new.py +610 -0
- _utils/config.yaml +15 -0
- _utils/example_config.yaml +20 -0
- _utils/load_models.py +16 -0
- _utils/load_track_data.py +118 -0
- _utils/misc_helper.py +37 -0
- _utils/seg_eval.py +61 -0
- _utils/track_args.py +157 -0
- app.py +1638 -0
- config.py +44 -0
- counting.py +340 -0
- example_imgs/cnt/047cell.png +3 -0
- example_imgs/cnt/62_10.png +3 -0
- example_imgs/cnt/6800-17000_GTEX-XQ3S_Adipose-Subcutaneous.png +3 -0
- example_imgs/seg/003_img.png +3 -0
- example_imgs/seg/1-23 [Scan I08].png +3 -0
- example_imgs/seg/10X_B2_Tile-15.aligned.png +3 -0
- example_imgs/seg/1977_Well_F-5_Field_1.png +3 -0
- example_imgs/seg/200972823[5179]_RhoGGG_YAP_TAZ [200972823 Well K6 Field #2].png +3 -0
- example_imgs/seg/A172_Phase_C7_1_00d00h00m_1.png +3 -0
- example_imgs/seg/JE2NileRed_oilp22_PMP_101220_011_NR.png +3 -0
- example_imgs/seg/OpenTest_031.png +3 -0
- example_imgs/seg/X_24.png +3 -0
- example_imgs/seg/exp_A01_G002_0001.oir.png +3 -0
- example_imgs/tra/tracking_test_sequence.zip +3 -0
- example_imgs/tra/tracking_test_sequence2.zip +3 -0
- inference_count.py +237 -0
- inference_seg.py +87 -0
- inference_track.py +202 -0
- models/.DS_Store +0 -0
- models/enc_model/__init__.py +0 -0
- models/enc_model/backbone.py +64 -0
- models/enc_model/loca.py +232 -0
- models/enc_model/loca_args.py +44 -0
- models/enc_model/mlp.py +23 -0
- models/enc_model/ope.py +245 -0
- models/enc_model/positional_encoding.py +30 -0
- models/enc_model/regression_head.py +92 -0
- models/enc_model/transformer.py +94 -0
- models/enc_model/unet_parts.py +77 -0
- models/model.py +653 -0
- models/seg_post_model/cellpose/__init__.py +1 -0
- models/seg_post_model/cellpose/__main__.py +272 -0
- models/seg_post_model/cellpose/cli.py +240 -0
- models/seg_post_model/cellpose/core.py +322 -0
- models/seg_post_model/cellpose/denoise.py +1474 -0
- models/seg_post_model/cellpose/dynamics.py +691 -0
- models/seg_post_model/cellpose/export.py +405 -0
README.md
CHANGED
|
@@ -1,11 +1,12 @@
|
|
| 1 |
---
|
| 2 |
title: MicroscopyMatching
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version:
|
| 8 |
app_file: app.py
|
|
|
|
| 9 |
pinned: false
|
| 10 |
---
|
| 11 |
|
|
|
|
| 1 |
---
|
| 2 |
title: MicroscopyMatching
|
| 3 |
+
emoji: 🚀
|
| 4 |
+
colorFrom: gray
|
| 5 |
+
colorTo: red
|
| 6 |
sdk: gradio
|
| 7 |
+
sdk_version: 5.49.1
|
| 8 |
app_file: app.py
|
| 9 |
+
python_version: 3.11
|
| 10 |
pinned: false
|
| 11 |
---
|
| 12 |
|
_utils/attn_utils.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from IPython.display import display
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from typing import Union, Tuple, List
|
| 9 |
+
from einops import rearrange, repeat
|
| 10 |
+
import math
|
| 11 |
+
from torch import nn, einsum
|
| 12 |
+
from inspect import isfunction
|
| 13 |
+
from diffusers.utils import logging
|
| 14 |
+
try:
|
| 15 |
+
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
| 16 |
+
except:
|
| 17 |
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
| 18 |
+
|
| 19 |
+
try:
|
| 20 |
+
from diffusers.models.cross_attention import CrossAttention
|
| 21 |
+
except:
|
| 22 |
+
from diffusers.models.attention_processor import Attention as CrossAttention
|
| 23 |
+
|
| 24 |
+
MAX_NUM_WORDS = 77
|
| 25 |
+
LOW_RESOURCE = False
|
| 26 |
+
|
| 27 |
+
class CountingCrossAttnProcessor1:
|
| 28 |
+
|
| 29 |
+
def __init__(self, attnstore, place_in_unet):
|
| 30 |
+
super().__init__()
|
| 31 |
+
self.attnstore = attnstore
|
| 32 |
+
self.place_in_unet = place_in_unet
|
| 33 |
+
|
| 34 |
+
def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
|
| 35 |
+
batch_size, sequence_length, dim = hidden_states.shape
|
| 36 |
+
h = attn_layer.heads
|
| 37 |
+
q = attn_layer.to_q(hidden_states)
|
| 38 |
+
is_cross = encoder_hidden_states is not None
|
| 39 |
+
context = encoder_hidden_states if is_cross else hidden_states
|
| 40 |
+
k = attn_layer.to_k(context)
|
| 41 |
+
v = attn_layer.to_v(context)
|
| 42 |
+
# q = attn_layer.reshape_heads_to_batch_dim(q)
|
| 43 |
+
# k = attn_layer.reshape_heads_to_batch_dim(k)
|
| 44 |
+
# v = attn_layer.reshape_heads_to_batch_dim(v)
|
| 45 |
+
# q = attn_layer.head_to_batch_dim(q)
|
| 46 |
+
# k = attn_layer.head_to_batch_dim(k)
|
| 47 |
+
# v = attn_layer.head_to_batch_dim(v)
|
| 48 |
+
q = self.head_to_batch_dim(q, h)
|
| 49 |
+
k = self.head_to_batch_dim(k, h)
|
| 50 |
+
v = self.head_to_batch_dim(v, h)
|
| 51 |
+
|
| 52 |
+
sim = torch.einsum("b i d, b j d -> b i j", q, k) * attn_layer.scale
|
| 53 |
+
|
| 54 |
+
if attention_mask is not None:
|
| 55 |
+
attention_mask = attention_mask.reshape(batch_size, -1)
|
| 56 |
+
max_neg_value = -torch.finfo(sim.dtype).max
|
| 57 |
+
attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
|
| 58 |
+
sim.masked_fill_(~attention_mask, max_neg_value)
|
| 59 |
+
|
| 60 |
+
# attention, what we cannot get enough of
|
| 61 |
+
attn_ = sim.softmax(dim=-1).clone()
|
| 62 |
+
# softmax = nn.Softmax(dim=-1)
|
| 63 |
+
# attn_ = softmax(sim)
|
| 64 |
+
self.attnstore(attn_, is_cross, self.place_in_unet)
|
| 65 |
+
out = torch.einsum("b i j, b j d -> b i d", attn_, v)
|
| 66 |
+
# out = attn_layer.batch_to_head_dim(out)
|
| 67 |
+
out = self.batch_to_head_dim(out, h)
|
| 68 |
+
|
| 69 |
+
if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList:
|
| 70 |
+
to_out = attn_layer.to_out[0]
|
| 71 |
+
else:
|
| 72 |
+
to_out = attn_layer.to_out
|
| 73 |
+
|
| 74 |
+
out = to_out(out)
|
| 75 |
+
return out
|
| 76 |
+
|
| 77 |
+
def batch_to_head_dim(self, tensor, head_size):
|
| 78 |
+
# head_size = self.heads
|
| 79 |
+
batch_size, seq_len, dim = tensor.shape
|
| 80 |
+
tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
|
| 81 |
+
tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
|
| 82 |
+
return tensor
|
| 83 |
+
|
| 84 |
+
def head_to_batch_dim(self, tensor, head_size, out_dim=3):
|
| 85 |
+
# head_size = self.heads
|
| 86 |
+
batch_size, seq_len, dim = tensor.shape
|
| 87 |
+
tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
|
| 88 |
+
tensor = tensor.permute(0, 2, 1, 3)
|
| 89 |
+
|
| 90 |
+
if out_dim == 3:
|
| 91 |
+
tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)
|
| 92 |
+
|
| 93 |
+
return tensor
|
| 94 |
+
|
| 95 |
+
|
| 96 |
+
def register_attention_control(model, controller):
|
| 97 |
+
|
| 98 |
+
attn_procs = {}
|
| 99 |
+
cross_att_count = 0
|
| 100 |
+
for name in model.unet.attn_processors.keys():
|
| 101 |
+
cross_attention_dim = None if name.endswith("attn1.processor") else model.unet.config.cross_attention_dim
|
| 102 |
+
if name.startswith("mid_block"):
|
| 103 |
+
hidden_size = model.unet.config.block_out_channels[-1]
|
| 104 |
+
place_in_unet = "mid"
|
| 105 |
+
elif name.startswith("up_blocks"):
|
| 106 |
+
block_id = int(name[len("up_blocks.")])
|
| 107 |
+
hidden_size = list(reversed(model.unet.config.block_out_channels))[block_id]
|
| 108 |
+
place_in_unet = "up"
|
| 109 |
+
elif name.startswith("down_blocks"):
|
| 110 |
+
block_id = int(name[len("down_blocks.")])
|
| 111 |
+
hidden_size = model.unet.config.block_out_channels[block_id]
|
| 112 |
+
place_in_unet = "down"
|
| 113 |
+
else:
|
| 114 |
+
continue
|
| 115 |
+
|
| 116 |
+
cross_att_count += 1
|
| 117 |
+
# attn_procs[name] = AttendExciteCrossAttnProcessor(
|
| 118 |
+
# attnstore=controller, place_in_unet=place_in_unet
|
| 119 |
+
# )
|
| 120 |
+
attn_procs[name] = CountingCrossAttnProcessor1(
|
| 121 |
+
attnstore=controller, place_in_unet=place_in_unet
|
| 122 |
+
)
|
| 123 |
+
|
| 124 |
+
model.unet.set_attn_processor(attn_procs)
|
| 125 |
+
controller.num_att_layers = cross_att_count
|
| 126 |
+
|
| 127 |
+
def register_hier_output(model):
|
| 128 |
+
self = model.unet
|
| 129 |
+
from ldm.modules.diffusionmodules.util import checkpoint, timestep_embedding
|
| 130 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 131 |
+
def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None,
|
| 132 |
+
attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None, down_block_additional_residuals=None,
|
| 133 |
+
mid_block_additional_residual=None, encoder_attention_mask=None, return_dict=True):
|
| 134 |
+
|
| 135 |
+
out_list = []
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
default_overall_up_factor = 2**self.num_upsamplers
|
| 139 |
+
|
| 140 |
+
# upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
|
| 141 |
+
forward_upsample_size = False
|
| 142 |
+
upsample_size = None
|
| 143 |
+
|
| 144 |
+
if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
|
| 145 |
+
logger.info("Forward upsample size to force interpolation output size.")
|
| 146 |
+
forward_upsample_size = True
|
| 147 |
+
|
| 148 |
+
if attention_mask is not None:
|
| 149 |
+
# assume that mask is expressed as:
|
| 150 |
+
# (1 = keep, 0 = discard)
|
| 151 |
+
# convert mask into a bias that can be added to attention scores:
|
| 152 |
+
# (keep = +0, discard = -10000.0)
|
| 153 |
+
attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
|
| 154 |
+
attention_mask = attention_mask.unsqueeze(1)
|
| 155 |
+
|
| 156 |
+
if encoder_attention_mask is not None:
|
| 157 |
+
encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
|
| 158 |
+
encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
|
| 159 |
+
|
| 160 |
+
if self.config.center_input_sample:
|
| 161 |
+
sample = 2 * sample - 1.0
|
| 162 |
+
|
| 163 |
+
timesteps = timestep
|
| 164 |
+
if not torch.is_tensor(timesteps):
|
| 165 |
+
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
| 166 |
+
# This would be a good case for the `match` statement (Python 3.10+)
|
| 167 |
+
is_mps = sample.device.type == "mps"
|
| 168 |
+
if isinstance(timestep, float):
|
| 169 |
+
dtype = torch.float32 if is_mps else torch.float64
|
| 170 |
+
else:
|
| 171 |
+
dtype = torch.int32 if is_mps else torch.int64
|
| 172 |
+
timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
|
| 173 |
+
elif len(timesteps.shape) == 0:
|
| 174 |
+
timesteps = timesteps[None].to(sample.device)
|
| 175 |
+
|
| 176 |
+
# broadcast to batch dimension in a way that's compatible with ONNX/Core ML
|
| 177 |
+
timesteps = timesteps.expand(sample.shape[0])
|
| 178 |
+
|
| 179 |
+
t_emb = self.time_proj(timesteps)
|
| 180 |
+
|
| 181 |
+
t_emb = t_emb.to(dtype=sample.dtype)
|
| 182 |
+
|
| 183 |
+
emb = self.time_embedding(t_emb, timestep_cond)
|
| 184 |
+
aug_emb = None
|
| 185 |
+
|
| 186 |
+
if self.class_embedding is not None:
|
| 187 |
+
if class_labels is None:
|
| 188 |
+
raise ValueError("class_labels should be provided when num_class_embeds > 0")
|
| 189 |
+
|
| 190 |
+
if self.config.class_embed_type == "timestep":
|
| 191 |
+
class_labels = self.time_proj(class_labels)
|
| 192 |
+
|
| 193 |
+
# `Timesteps` does not contain any weights and will always return f32 tensors
|
| 194 |
+
# there might be better ways to encapsulate this.
|
| 195 |
+
class_labels = class_labels.to(dtype=sample.dtype)
|
| 196 |
+
|
| 197 |
+
class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)
|
| 198 |
+
|
| 199 |
+
if self.config.class_embeddings_concat:
|
| 200 |
+
emb = torch.cat([emb, class_emb], dim=-1)
|
| 201 |
+
else:
|
| 202 |
+
emb = emb + class_emb
|
| 203 |
+
|
| 204 |
+
if self.config.addition_embed_type == "text":
|
| 205 |
+
aug_emb = self.add_embedding(encoder_hidden_states)
|
| 206 |
+
elif self.config.addition_embed_type == "text_image":
|
| 207 |
+
# Kandinsky 2.1 - style
|
| 208 |
+
if "image_embeds" not in added_cond_kwargs:
|
| 209 |
+
raise ValueError(
|
| 210 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
| 211 |
+
)
|
| 212 |
+
|
| 213 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
| 214 |
+
text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
|
| 215 |
+
aug_emb = self.add_embedding(text_embs, image_embs)
|
| 216 |
+
elif self.config.addition_embed_type == "text_time":
|
| 217 |
+
# SDXL - style
|
| 218 |
+
if "text_embeds" not in added_cond_kwargs:
|
| 219 |
+
raise ValueError(
|
| 220 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
|
| 221 |
+
)
|
| 222 |
+
text_embeds = added_cond_kwargs.get("text_embeds")
|
| 223 |
+
if "time_ids" not in added_cond_kwargs:
|
| 224 |
+
raise ValueError(
|
| 225 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
|
| 226 |
+
)
|
| 227 |
+
time_ids = added_cond_kwargs.get("time_ids")
|
| 228 |
+
time_embeds = self.add_time_proj(time_ids.flatten())
|
| 229 |
+
time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
|
| 230 |
+
|
| 231 |
+
add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
|
| 232 |
+
add_embeds = add_embeds.to(emb.dtype)
|
| 233 |
+
aug_emb = self.add_embedding(add_embeds)
|
| 234 |
+
elif self.config.addition_embed_type == "image":
|
| 235 |
+
# Kandinsky 2.2 - style
|
| 236 |
+
if "image_embeds" not in added_cond_kwargs:
|
| 237 |
+
raise ValueError(
|
| 238 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
|
| 239 |
+
)
|
| 240 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
| 241 |
+
aug_emb = self.add_embedding(image_embs)
|
| 242 |
+
elif self.config.addition_embed_type == "image_hint":
|
| 243 |
+
# Kandinsky 2.2 - style
|
| 244 |
+
if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
|
| 245 |
+
raise ValueError(
|
| 246 |
+
f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
|
| 247 |
+
)
|
| 248 |
+
image_embs = added_cond_kwargs.get("image_embeds")
|
| 249 |
+
hint = added_cond_kwargs.get("hint")
|
| 250 |
+
aug_emb, hint = self.add_embedding(image_embs, hint)
|
| 251 |
+
sample = torch.cat([sample, hint], dim=1)
|
| 252 |
+
|
| 253 |
+
emb = emb + aug_emb if aug_emb is not None else emb
|
| 254 |
+
|
| 255 |
+
if self.time_embed_act is not None:
|
| 256 |
+
emb = self.time_embed_act(emb)
|
| 257 |
+
|
| 258 |
+
if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
|
| 259 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
|
| 260 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
|
| 261 |
+
# Kadinsky 2.1 - style
|
| 262 |
+
if "image_embeds" not in added_cond_kwargs:
|
| 263 |
+
raise ValueError(
|
| 264 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
| 265 |
+
)
|
| 266 |
+
|
| 267 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
| 268 |
+
encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
|
| 269 |
+
elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
|
| 270 |
+
# Kandinsky 2.2 - style
|
| 271 |
+
if "image_embeds" not in added_cond_kwargs:
|
| 272 |
+
raise ValueError(
|
| 273 |
+
f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
|
| 274 |
+
)
|
| 275 |
+
image_embeds = added_cond_kwargs.get("image_embeds")
|
| 276 |
+
encoder_hidden_states = self.encoder_hid_proj(image_embeds)
|
| 277 |
+
# 2. pre-process
|
| 278 |
+
sample = self.conv_in(sample) # 1, 320, 64, 64
|
| 279 |
+
|
| 280 |
+
# 2.5 GLIGEN position net
|
| 281 |
+
if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
|
| 282 |
+
cross_attention_kwargs = cross_attention_kwargs.copy()
|
| 283 |
+
gligen_args = cross_attention_kwargs.pop("gligen")
|
| 284 |
+
cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
|
| 285 |
+
|
| 286 |
+
# 3. down
|
| 287 |
+
lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
|
| 288 |
+
|
| 289 |
+
is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
|
| 290 |
+
is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None
|
| 291 |
+
|
| 292 |
+
down_block_res_samples = (sample,)
|
| 293 |
+
|
| 294 |
+
for downsample_block in self.down_blocks:
|
| 295 |
+
if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
|
| 296 |
+
# For t2i-adapter CrossAttnDownBlock2D
|
| 297 |
+
additional_residuals = {}
|
| 298 |
+
if is_adapter and len(down_block_additional_residuals) > 0:
|
| 299 |
+
additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)
|
| 300 |
+
|
| 301 |
+
sample, res_samples = downsample_block(
|
| 302 |
+
hidden_states=sample,
|
| 303 |
+
temb=emb,
|
| 304 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 305 |
+
attention_mask=attention_mask,
|
| 306 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 307 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 308 |
+
**additional_residuals,
|
| 309 |
+
)
|
| 310 |
+
else:
|
| 311 |
+
sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
|
| 312 |
+
|
| 313 |
+
if is_adapter and len(down_block_additional_residuals) > 0:
|
| 314 |
+
sample += down_block_additional_residuals.pop(0)
|
| 315 |
+
|
| 316 |
+
down_block_res_samples += res_samples
|
| 317 |
+
|
| 318 |
+
if is_controlnet:
|
| 319 |
+
new_down_block_res_samples = ()
|
| 320 |
+
|
| 321 |
+
for down_block_res_sample, down_block_additional_residual in zip(
|
| 322 |
+
down_block_res_samples, down_block_additional_residuals
|
| 323 |
+
):
|
| 324 |
+
down_block_res_sample = down_block_res_sample + down_block_additional_residual
|
| 325 |
+
new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)
|
| 326 |
+
|
| 327 |
+
down_block_res_samples = new_down_block_res_samples
|
| 328 |
+
|
| 329 |
+
# 4. mid
|
| 330 |
+
if self.mid_block is not None:
|
| 331 |
+
sample = self.mid_block(
|
| 332 |
+
sample,
|
| 333 |
+
emb,
|
| 334 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 335 |
+
attention_mask=attention_mask,
|
| 336 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 337 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 338 |
+
)
|
| 339 |
+
# To support T2I-Adapter-XL
|
| 340 |
+
if (
|
| 341 |
+
is_adapter
|
| 342 |
+
and len(down_block_additional_residuals) > 0
|
| 343 |
+
and sample.shape == down_block_additional_residuals[0].shape
|
| 344 |
+
):
|
| 345 |
+
sample += down_block_additional_residuals.pop(0)
|
| 346 |
+
|
| 347 |
+
if is_controlnet:
|
| 348 |
+
sample = sample + mid_block_additional_residual
|
| 349 |
+
|
| 350 |
+
# 5. up
|
| 351 |
+
for i, upsample_block in enumerate(self.up_blocks):
|
| 352 |
+
is_final_block = i == len(self.up_blocks) - 1
|
| 353 |
+
|
| 354 |
+
res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
|
| 355 |
+
down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
|
| 356 |
+
|
| 357 |
+
# if we have not reached the final block and need to forward the
|
| 358 |
+
# upsample size, we do it here
|
| 359 |
+
if not is_final_block and forward_upsample_size:
|
| 360 |
+
upsample_size = down_block_res_samples[-1].shape[2:]
|
| 361 |
+
|
| 362 |
+
if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
|
| 363 |
+
sample = upsample_block(
|
| 364 |
+
hidden_states=sample,
|
| 365 |
+
temb=emb,
|
| 366 |
+
res_hidden_states_tuple=res_samples,
|
| 367 |
+
encoder_hidden_states=encoder_hidden_states,
|
| 368 |
+
cross_attention_kwargs=cross_attention_kwargs,
|
| 369 |
+
upsample_size=upsample_size,
|
| 370 |
+
attention_mask=attention_mask,
|
| 371 |
+
encoder_attention_mask=encoder_attention_mask,
|
| 372 |
+
)
|
| 373 |
+
else:
|
| 374 |
+
sample = upsample_block(
|
| 375 |
+
hidden_states=sample,
|
| 376 |
+
temb=emb,
|
| 377 |
+
res_hidden_states_tuple=res_samples,
|
| 378 |
+
upsample_size=upsample_size,
|
| 379 |
+
scale=lora_scale,
|
| 380 |
+
)
|
| 381 |
+
|
| 382 |
+
# if i in [1, 4, 7]:
|
| 383 |
+
out_list.append(sample)
|
| 384 |
+
|
| 385 |
+
# 6. post-process
|
| 386 |
+
if self.conv_norm_out:
|
| 387 |
+
sample = self.conv_norm_out(sample)
|
| 388 |
+
sample = self.conv_act(sample)
|
| 389 |
+
sample = self.conv_out(sample)
|
| 390 |
+
|
| 391 |
+
if not return_dict:
|
| 392 |
+
return (sample,)
|
| 393 |
+
|
| 394 |
+
return UNet2DConditionOutput(sample=sample), out_list
|
| 395 |
+
|
| 396 |
+
self.forward = forward
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
class AttentionControl(abc.ABC):
|
| 400 |
+
|
| 401 |
+
def step_callback(self, x_t):
|
| 402 |
+
return x_t
|
| 403 |
+
|
| 404 |
+
def between_steps(self):
|
| 405 |
+
return
|
| 406 |
+
|
| 407 |
+
@property
|
| 408 |
+
def num_uncond_att_layers(self):
|
| 409 |
+
return 0
|
| 410 |
+
|
| 411 |
+
@abc.abstractmethod
|
| 412 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
| 413 |
+
raise NotImplementedError
|
| 414 |
+
|
| 415 |
+
def __call__(self, attn, is_cross: bool, place_in_unet: str):
|
| 416 |
+
if self.cur_att_layer >= self.num_uncond_att_layers:
|
| 417 |
+
# self.forward(attn, is_cross, place_in_unet)
|
| 418 |
+
if LOW_RESOURCE:
|
| 419 |
+
attn = self.forward(attn, is_cross, place_in_unet)
|
| 420 |
+
else:
|
| 421 |
+
h = attn.shape[0]
|
| 422 |
+
attn[h // 2:] = self.forward(attn[h // 2:], is_cross, place_in_unet)
|
| 423 |
+
self.cur_att_layer += 1
|
| 424 |
+
if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
|
| 425 |
+
self.cur_att_layer = 0
|
| 426 |
+
self.cur_step += 1
|
| 427 |
+
self.between_steps()
|
| 428 |
+
return attn
|
| 429 |
+
|
| 430 |
+
def reset(self):
|
| 431 |
+
self.cur_step = 0
|
| 432 |
+
self.cur_att_layer = 0
|
| 433 |
+
|
| 434 |
+
def __init__(self):
|
| 435 |
+
self.cur_step = 0
|
| 436 |
+
self.num_att_layers = -1
|
| 437 |
+
self.cur_att_layer = 0
|
| 438 |
+
|
| 439 |
+
|
| 440 |
+
class EmptyControl(AttentionControl):
|
| 441 |
+
|
| 442 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
| 443 |
+
return attn
|
| 444 |
+
|
| 445 |
+
|
| 446 |
+
class AttentionStore(AttentionControl):
|
| 447 |
+
|
| 448 |
+
@staticmethod
|
| 449 |
+
def get_empty_store():
|
| 450 |
+
return {"down_cross": [], "mid_cross": [], "up_cross": [],
|
| 451 |
+
"down_self": [], "mid_self": [], "up_self": []}
|
| 452 |
+
|
| 453 |
+
def forward(self, attn, is_cross: bool, place_in_unet: str):
|
| 454 |
+
key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
|
| 455 |
+
if attn.shape[1] <= self.max_size ** 2: # avoid memory overhead
|
| 456 |
+
self.step_store[key].append(attn)
|
| 457 |
+
return attn
|
| 458 |
+
|
| 459 |
+
def between_steps(self):
|
| 460 |
+
self.attention_store = self.step_store
|
| 461 |
+
if self.save_global_store:
|
| 462 |
+
with torch.no_grad():
|
| 463 |
+
if len(self.global_store) == 0:
|
| 464 |
+
self.global_store = self.step_store
|
| 465 |
+
else:
|
| 466 |
+
for key in self.global_store:
|
| 467 |
+
for i in range(len(self.global_store[key])):
|
| 468 |
+
self.global_store[key][i] += self.step_store[key][i].detach()
|
| 469 |
+
self.step_store = self.get_empty_store()
|
| 470 |
+
self.step_store = self.get_empty_store()
|
| 471 |
+
|
| 472 |
+
def get_average_attention(self):
|
| 473 |
+
average_attention = self.attention_store
|
| 474 |
+
return average_attention
|
| 475 |
+
|
| 476 |
+
def get_average_global_attention(self):
|
| 477 |
+
average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in
|
| 478 |
+
self.attention_store}
|
| 479 |
+
return average_attention
|
| 480 |
+
|
| 481 |
+
def reset(self):
|
| 482 |
+
super(AttentionStore, self).reset()
|
| 483 |
+
self.step_store = self.get_empty_store()
|
| 484 |
+
self.attention_store = {}
|
| 485 |
+
self.global_store = {}
|
| 486 |
+
|
| 487 |
+
def __init__(self, max_size=32, save_global_store=False):
|
| 488 |
+
'''
|
| 489 |
+
Initialize an empty AttentionStore
|
| 490 |
+
:param step_index: used to visualize only a specific step in the diffusion process
|
| 491 |
+
'''
|
| 492 |
+
super(AttentionStore, self).__init__()
|
| 493 |
+
self.save_global_store = save_global_store
|
| 494 |
+
self.max_size = max_size
|
| 495 |
+
self.step_store = self.get_empty_store()
|
| 496 |
+
self.attention_store = {}
|
| 497 |
+
self.global_store = {}
|
| 498 |
+
self.curr_step_index = 0
|
| 499 |
+
|
| 500 |
+
def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
|
| 501 |
+
out = []
|
| 502 |
+
attention_maps = attention_store.get_average_attention()
|
| 503 |
+
num_pixels = res ** 2
|
| 504 |
+
for location in from_where:
|
| 505 |
+
for item in attention_maps[f"{location}_{'cross' if is_cross else 'self'}"]:
|
| 506 |
+
if item.shape[1] == num_pixels:
|
| 507 |
+
cross_maps = item.reshape(len(prompts), -1, res, res, item.shape[-1])[select]
|
| 508 |
+
out.append(cross_maps)
|
| 509 |
+
out = torch.cat(out, dim=0)
|
| 510 |
+
out = out.sum(0) / out.shape[0]
|
| 511 |
+
return out
|
| 512 |
+
|
| 513 |
+
|
| 514 |
+
def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
|
| 515 |
+
tokens = tokenizer.encode(prompts[select])
|
| 516 |
+
decoder = tokenizer.decode
|
| 517 |
+
attention_maps = aggregate_attention(attention_store, res, from_where, True, select)
|
| 518 |
+
images = []
|
| 519 |
+
for i in range(len(tokens)):
|
| 520 |
+
image = attention_maps[:, :, i]
|
| 521 |
+
image = 255 * image / image.max()
|
| 522 |
+
image = image.unsqueeze(-1).expand(*image.shape, 3)
|
| 523 |
+
image = image.numpy().astype(np.uint8)
|
| 524 |
+
image = np.array(Image.fromarray(image).resize((256, 256)))
|
| 525 |
+
image = text_under_image(image, decoder(int(tokens[i])))
|
| 526 |
+
images.append(image)
|
| 527 |
+
view_images(np.stack(images, axis=0))
|
| 528 |
+
|
| 529 |
+
|
| 530 |
+
def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
|
| 531 |
+
max_com=10, select: int = 0):
|
| 532 |
+
attention_maps = aggregate_attention(attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
|
| 533 |
+
u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
|
| 534 |
+
images = []
|
| 535 |
+
for i in range(max_com):
|
| 536 |
+
image = vh[i].reshape(res, res)
|
| 537 |
+
image = image - image.min()
|
| 538 |
+
image = 255 * image / image.max()
|
| 539 |
+
image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
|
| 540 |
+
image = Image.fromarray(image).resize((256, 256))
|
| 541 |
+
image = np.array(image)
|
| 542 |
+
images.append(image)
|
| 543 |
+
view_images(np.concatenate(images, axis=1))
|
| 544 |
+
|
| 545 |
+
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
|
| 546 |
+
h, w, c = image.shape
|
| 547 |
+
offset = int(h * .2)
|
| 548 |
+
img = np.ones((h + offset, w, c), dtype=np.uint8) * 255
|
| 549 |
+
font = cv2.FONT_HERSHEY_SIMPLEX
|
| 550 |
+
# font = ImageFont.truetype("/usr/share/fonts/truetype/noto/NotoMono-Regular.ttf", font_size)
|
| 551 |
+
img[:h] = image
|
| 552 |
+
textsize = cv2.getTextSize(text, font, 1, 2)[0]
|
| 553 |
+
text_x, text_y = (w - textsize[0]) // 2, h + offset - textsize[1] // 2
|
| 554 |
+
cv2.putText(img, text, (text_x, text_y ), font, 1, text_color, 2)
|
| 555 |
+
return img
|
| 556 |
+
|
| 557 |
+
|
| 558 |
+
def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tile `images` into a grid (padding with white tiles) and display it inline."""
    # Normalize input to a list and work out how many blank tiles are appended.
    if type(images) is list:
        num_empty = len(images) % num_rows
    elif images.ndim == 4:
        num_empty = images.shape[0] % num_rows
    else:
        images = [images]
        num_empty = 0

    blank = np.ones(images[0].shape, dtype=np.uint8) * 255
    tiles = [img.astype(np.uint8) for img in images]
    tiles.extend([blank] * num_empty)

    h, w, c = tiles[0].shape
    gap = int(h * offset_ratio)
    num_cols = len(tiles) // num_rows

    grid_h = h * num_rows + gap * (num_rows - 1)
    grid_w = w * num_cols + gap * (num_cols - 1)
    grid = np.ones((grid_h, grid_w, 3), dtype=np.uint8) * 255
    for row in range(num_rows):
        for col in range(num_cols):
            top = row * (h + gap)
            left = col * (w + gap)
            grid[top: top + h, left: left + w] = tiles[row * num_cols + col]

    display(Image.fromarray(grid))
|
| 583 |
+
|
| 584 |
+
def self_cross_attn(self_attn, cross_attn):
    """Weight a self-attention volume by a cross-attention map.

    self_attn: (res, res, res*res); cross_attn: (res, res).
    Returns a (1, 1, res, res) map: mean over the last axis of the weighted volume.
    """
    res = self_attn.shape[0]
    assert res == cross_attn.shape[0]
    # cross attn [res, res] -> [res*res], broadcast over the self-attention volume
    weights = cross_attn.reshape([res * res])
    fused = (weights * self_attn).mean(-1)
    return fused.unsqueeze(0).unsqueeze(0)
|
_utils/attn_utils_new.py
ADDED
|
@@ -0,0 +1,610 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import abc
|
| 2 |
+
|
| 3 |
+
import cv2
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
from IPython.display import display
|
| 7 |
+
from PIL import Image
|
| 8 |
+
from typing import Union, Tuple, List
|
| 9 |
+
from einops import rearrange, repeat
|
| 10 |
+
import math
|
| 11 |
+
from torch import nn, einsum
|
| 12 |
+
from inspect import isfunction
|
| 13 |
+
from diffusers.utils import logging
|
| 14 |
+
try:
|
| 15 |
+
from diffusers.models.unet_2d_condition import UNet2DConditionOutput
|
| 16 |
+
except:
|
| 17 |
+
from diffusers.models.unets.unet_2d_condition import UNet2DConditionOutput
|
| 18 |
+
try:
|
| 19 |
+
from diffusers.models.cross_attention import CrossAttention
|
| 20 |
+
except:
|
| 21 |
+
from diffusers.models.attention_processor import Attention as CrossAttention
|
| 22 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 23 |
+
MAX_NUM_WORDS = 77
|
| 24 |
+
LOW_RESOURCE = False
|
| 25 |
+
|
| 26 |
+
class CountingCrossAttnProcessor1:
    """diffusers attention processor that re-implements the CrossAttention forward
    pass while forwarding every softmaxed attention map to `attnstore`.

    `attnstore` is expected to be callable as ``attnstore(attn, is_cross,
    place_in_unet)`` — e.g. an AttentionControl/AttentionStore instance.
    """

    def __init__(self, attnstore, place_in_unet):
        super().__init__()
        self.attnstore = attnstore          # sink called with every attention map
        self.place_in_unet = place_in_unet  # "down" | "mid" | "up"

    def __call__(self, attn_layer: CrossAttention, hidden_states, encoder_hidden_states=None, attention_mask=None):
        # Cross-attention when encoder_hidden_states is given, self-attention otherwise.
        batch_size, sequence_length, dim = hidden_states.shape
        h = attn_layer.heads
        q = attn_layer.to_q(hidden_states)
        is_cross = encoder_hidden_states is not None
        context = encoder_hidden_states if is_cross else hidden_states
        k = attn_layer.to_k(context)
        v = attn_layer.to_v(context)
        # Local head reshapes are used instead of attn_layer's own helpers, which
        # are named differently across diffusers versions.
        q = self.head_to_batch_dim(q, h)
        k = self.head_to_batch_dim(k, h)
        v = self.head_to_batch_dim(v, h)

        sim = torch.einsum("b i d, b j d -> b i j", q, k) * attn_layer.scale

        if attention_mask is not None:
            # Boolean mask: True = keep; masked positions get the most negative
            # representable value so softmax sends them to ~0.
            attention_mask = attention_mask.reshape(batch_size, -1)
            max_neg_value = -torch.finfo(sim.dtype).max
            attention_mask = attention_mask[:, None, :].repeat(h, 1, 1)
            sim.masked_fill_(~attention_mask, max_neg_value)

        # attention, what we cannot get enough of
        attn_ = sim.softmax(dim=-1).clone()
        # Hand the map to the external store; it may modify it in place.
        self.attnstore(attn_, is_cross, self.place_in_unet)
        out = torch.einsum("b i j, b j d -> b i d", attn_, v)
        out = self.batch_to_head_dim(out, h)

        # to_out is a ModuleList (projection + dropout) in newer diffusers, a
        # single module in older ones; only the projection is applied here.
        if type(attn_layer.to_out) is torch.nn.modules.container.ModuleList:
            to_out = attn_layer.to_out[0]
        else:
            to_out = attn_layer.to_out

        out = to_out(out)
        return out

    def batch_to_head_dim(self, tensor, head_size):
        # (batch * heads, seq, dim) -> (batch, seq, dim * heads)
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
        tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
        return tensor

    def head_to_batch_dim(self, tensor, head_size, out_dim=3):
        # (batch, seq, dim) -> (batch * heads, seq, dim // heads) when out_dim == 3
        batch_size, seq_len, dim = tensor.shape
        tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
        tensor = tensor.permute(0, 2, 1, 3)

        if out_dim == 3:
            tensor = tensor.reshape(batch_size * head_size, seq_len, dim // head_size)

        return tensor
|
| 93 |
+
|
| 94 |
+
|
| 95 |
+
def register_attention_control(model, controller):
    """Replace every attention processor in `model.unet` with a
    CountingCrossAttnProcessor1 that reports its maps to `controller`.

    Also sets ``controller.num_att_layers`` to the number of hooked layers,
    which AttentionControl uses to detect diffusion-step boundaries.

    :param model: pipeline-like object exposing a diffusers ``unet``
    :param controller: AttentionControl instance receiving the attention maps
    """
    attn_procs = {}
    cross_att_count = 0
    for name in model.unet.attn_processors.keys():
        # Only processors in a recognized UNet section are hooked; the original
        # also computed cross_attention_dim / hidden_size / block_id here, but
        # CountingCrossAttnProcessor1 never uses them, so they are dropped.
        if name.startswith("mid_block"):
            place_in_unet = "mid"
        elif name.startswith("up_blocks"):
            place_in_unet = "up"
        elif name.startswith("down_blocks"):
            place_in_unet = "down"
        else:
            continue

        cross_att_count += 1
        attn_procs[name] = CountingCrossAttnProcessor1(
            attnstore=controller, place_in_unet=place_in_unet
        )

    model.unet.set_attn_processor(attn_procs)
    controller.num_att_layers = cross_att_count
|
| 125 |
+
|
| 126 |
+
def register_hier_output(model):
    """Monkey-patch ``model.unet.forward`` so it also returns the hidden states of
    every up-block ("hierarchical" features) alongside the usual UNet output.

    The replacement is a copy of diffusers' ``UNet2DConditionModel.forward``; the
    functional change is that each up-block output is appended to ``out_list`` and
    the function returns ``(UNet2DConditionOutput, out_list)`` when
    ``return_dict`` is True.

    NOTE(review): ``forward`` is defined without ``self`` and closes over the
    unet instance, so it must be called as ``model.unet.forward(sample, ...)``.
    """
    self = model.unet
    logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

    def forward(sample, timestep=None, encoder_hidden_states=None, class_labels=None, timestep_cond=None,
                attention_mask=None, cross_attention_kwargs=None, added_cond_kwargs=None, down_block_additional_residuals=None,
                mid_block_additional_residual=None, encoder_attention_mask=None, return_dict=True):

        out_list = []  # hidden states collected from each up-block

        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        if attention_mask is not None:
            # assume that mask is expressed as:
            # (1 = keep, 0 = discard)
            # convert mask into a bias that can be added to attention scores:
            # (keep = +0, discard = -10000.0)
            attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
            attention_mask = attention_mask.unsqueeze(1)

        if encoder_attention_mask is not None:
            encoder_attention_mask = (1 - encoder_attention_mask.to(sample.dtype)) * -10000.0
            encoder_attention_mask = encoder_attention_mask.unsqueeze(1)

        if self.config.center_input_sample:
            sample = 2 * sample - 1.0

        # 1. time embedding
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            if isinstance(timestep, float):
                dtype = torch.float32 if is_mps else torch.float64
            else:
                dtype = torch.int32 if is_mps else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        timesteps = timesteps.expand(sample.shape[0])

        t_emb = self.time_proj(timesteps)

        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb, timestep_cond)
        aug_emb = None

        if self.class_embedding is not None:
            if class_labels is None:
                raise ValueError("class_labels should be provided when num_class_embeds > 0")

            if self.config.class_embed_type == "timestep":
                class_labels = self.time_proj(class_labels)

                # `Timesteps` does not contain any weights and will always return f32 tensors
                # there might be better ways to encapsulate this.
                class_labels = class_labels.to(dtype=sample.dtype)

            class_emb = self.class_embedding(class_labels).to(dtype=sample.dtype)

            if self.config.class_embeddings_concat:
                emb = torch.cat([emb, class_emb], dim=-1)
            else:
                emb = emb + class_emb

        if self.config.addition_embed_type == "text":
            aug_emb = self.add_embedding(encoder_hidden_states)
        elif self.config.addition_embed_type == "text_image":
            # Kandinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )

            image_embs = added_cond_kwargs.get("image_embeds")
            text_embs = added_cond_kwargs.get("text_embeds", encoder_hidden_states)
            aug_emb = self.add_embedding(text_embs, image_embs)
        elif self.config.addition_embed_type == "text_time":
            # SDXL - style
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `text_embeds` to be passed in `added_cond_kwargs`"
                )
            text_embeds = added_cond_kwargs.get("text_embeds")
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'text_time' which requires the keyword argument `time_ids` to be passed in `added_cond_kwargs`"
                )
            time_ids = added_cond_kwargs.get("time_ids")
            time_embeds = self.add_time_proj(time_ids.flatten())
            time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))

            add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
            add_embeds = add_embeds.to(emb.dtype)
            aug_emb = self.add_embedding(add_embeds)
        elif self.config.addition_embed_type == "image":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            aug_emb = self.add_embedding(image_embs)
        elif self.config.addition_embed_type == "image_hint":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
                )
            image_embs = added_cond_kwargs.get("image_embeds")
            hint = added_cond_kwargs.get("hint")
            aug_emb, hint = self.add_embedding(image_embs, hint)
            sample = torch.cat([sample, hint], dim=1)

        emb = emb + aug_emb if aug_emb is not None else emb

        if self.time_embed_act is not None:
            emb = self.time_embed_act(emb)

        if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
            # Kadinsky 2.1 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )

            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states, image_embeds)
        elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "image_proj":
            # Kandinsky 2.2 - style
            if "image_embeds" not in added_cond_kwargs:
                raise ValueError(
                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
                )
            image_embeds = added_cond_kwargs.get("image_embeds")
            encoder_hidden_states = self.encoder_hid_proj(image_embeds)
        # 2. pre-process
        sample = self.conv_in(sample)  # 1, 320, 64, 64

        # 2.5 GLIGEN position net
        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
            cross_attention_kwargs = cross_attention_kwargs.copy()
            gligen_args = cross_attention_kwargs.pop("gligen")
            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}

        # 3. down
        lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0

        is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
        is_adapter = mid_block_additional_residual is None and down_block_additional_residuals is not None

        down_block_res_samples = (sample,)

        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                # For t2i-adapter CrossAttnDownBlock2D
                additional_residuals = {}
                if is_adapter and len(down_block_additional_residuals) > 0:
                    additional_residuals["additional_residuals"] = down_block_additional_residuals.pop(0)

                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    attention_mask=attention_mask,
                    cross_attention_kwargs=cross_attention_kwargs,
                    encoder_attention_mask=encoder_attention_mask,
                    **additional_residuals,
                )
            else:
                sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)

                if is_adapter and len(down_block_additional_residuals) > 0:
                    sample += down_block_additional_residuals.pop(0)

            down_block_res_samples += res_samples

        if is_controlnet:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample = down_block_res_sample + down_block_additional_residual
                new_down_block_res_samples = new_down_block_res_samples + (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. mid
        if self.mid_block is not None:
            sample = self.mid_block(
                sample,
                emb,
                encoder_hidden_states=encoder_hidden_states,
                attention_mask=attention_mask,
                cross_attention_kwargs=cross_attention_kwargs,
                encoder_attention_mask=encoder_attention_mask,
            )
            # To support T2I-Adapter-XL
            if (
                is_adapter
                and len(down_block_additional_residuals) > 0
                and sample.shape == down_block_additional_residuals[0].shape
            ):
                sample += down_block_additional_residuals.pop(0)

        if is_controlnet:
            sample = sample + mid_block_additional_residual

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    cross_attention_kwargs=cross_attention_kwargs,
                    upsample_size=upsample_size,
                    attention_mask=attention_mask,
                    encoder_attention_mask=encoder_attention_mask,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    scale=lora_scale,
                )

            # Hierarchical feature: record every up-block's hidden state.
            out_list.append(sample)

        # 6. post-process
        if self.conv_norm_out:
            sample = self.conv_norm_out(sample)
            sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        if not return_dict:
            # NOTE(review): the tuple path drops out_list — callers wanting the
            # hierarchical features must use return_dict=True.
            return (sample,)

        return UNet2DConditionOutput(sample=sample), out_list

    self.forward = forward
|
| 394 |
+
|
| 395 |
+
|
| 396 |
+
|
| 397 |
+
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
|
| 401 |
+
class AttentionControl(abc.ABC):
    """Base hook for capturing/editing attention maps during diffusion sampling.

    Subclasses implement `forward`; `__call__` routes each attention map through
    it while tracking the current attention layer and diffusion step.
    """

    def __init__(self):
        self.cur_step = 0
        self.num_att_layers = -1  # set externally (register_attention_control)
        self.cur_att_layer = 0

    def step_callback(self, x_t):
        """Optionally transform the latent between steps (identity by default)."""
        return x_t

    def between_steps(self):
        """Hook invoked once all attention layers of a step have been seen."""
        return

    @property
    def num_uncond_att_layers(self):
        return 0

    @abc.abstractmethod
    def forward(self, attn, is_cross: bool, place_in_unet: str):
        raise NotImplementedError

    def __call__(self, attn, is_cross: bool, place_in_unet: str):
        if self.cur_att_layer >= self.num_uncond_att_layers:
            if LOW_RESOURCE:
                attn = self.forward(attn, is_cross, place_in_unet)
            else:
                # Batched cond/uncond: only the conditional half is processed.
                half = attn.shape[0] // 2
                attn[half:] = self.forward(attn[half:], is_cross, place_in_unet)
        self.cur_att_layer += 1
        if self.cur_att_layer == self.num_att_layers + self.num_uncond_att_layers:
            # A full pass over the UNet's attention layers completed one step.
            self.cur_att_layer = 0
            self.cur_step += 1
            self.between_steps()
        return attn

    def reset(self):
        self.cur_step = 0
        self.cur_att_layer = 0
|
| 440 |
+
|
| 441 |
+
|
| 442 |
+
class EmptyControl(AttentionControl):
    """No-op controller: every attention map is passed through unmodified."""

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        # Identity — nothing is recorded or edited.
        return attn
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
class AttentionStore(AttentionControl):
    """AttentionControl that records the attention maps of each diffusion step.

    `step_store` accumulates maps within the current step, keyed by UNet section
    and attention type; `attention_store` holds the maps of the last finished
    step; `global_store` (optional) sums maps across all steps.
    """

    @staticmethod
    def get_empty_store():
        # One list per (UNet section, attention type) bucket.
        return {"down_cross": [], "mid_cross": [], "up_cross": [],
                "down_self": [], "mid_self": [], "up_self": []}

    def forward(self, attn, is_cross: bool, place_in_unet: str):
        key = f"{place_in_unet}_{'cross' if is_cross else 'self'}"
        if attn.shape[1] <= self.max_size ** 2:  # avoid memory overhead
            self.step_store[key].append(attn)
        return attn

    def between_steps(self):
        # Publish the finished step and (optionally) fold it into the global sum.
        self.attention_store = self.step_store
        if self.save_global_store:
            with torch.no_grad():
                if len(self.global_store) == 0:
                    self.global_store = self.step_store
                else:
                    for key in self.global_store:
                        for i in range(len(self.global_store[key])):
                            self.global_store[key][i] += self.step_store[key][i].detach()
        self.step_store = self.get_empty_store()
        # NOTE(review): redundant second reset kept from the original — harmless.
        self.step_store = self.get_empty_store()

    def get_average_attention(self):
        # Returns the last finished step's maps (no averaging is performed here).
        average_attention = self.attention_store
        return average_attention

    def get_average_global_attention(self):
        # Mean over all completed steps of the global accumulator.
        average_attention = {key: [item / self.cur_step for item in self.global_store[key]] for key in
                             self.attention_store}
        return average_attention

    def reset(self):
        super(AttentionStore, self).reset()
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}

    def __init__(self, max_size=32, save_global_store=False):
        '''
        Initialize an empty AttentionStore.
        :param max_size: maps with more than max_size**2 query positions are skipped
        :param save_global_store: if True, also accumulate maps across all steps
        '''
        super(AttentionStore, self).__init__()
        self.save_global_store = save_global_store
        self.max_size = max_size
        self.step_store = self.get_empty_store()
        self.attention_store = {}
        self.global_store = {}
        self.curr_step_index = 0
|
| 501 |
+
|
| 502 |
+
def aggregate_attention(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    """Average the stored attention maps of resolution `res` from the requested
    UNet sections for the prompt at index `select`.

    Returns a tensor of shape (res, res, last_dim).
    """
    kind = 'cross' if is_cross else 'self'
    num_pixels = res ** 2
    stored = attention_store.get_average_attention()

    maps = []
    for location in from_where:
        for item in stored[f"{location}_{kind}"]:
            if item.shape[1] == num_pixels:
                maps.append(item.reshape(len(prompts), -1, res, res, item.shape[-1])[select])

    stacked = torch.cat(maps, dim=0)
    return stacked.sum(0) / stacked.shape[0]
|
| 514 |
+
|
| 515 |
+
def aggregate_attention1(prompts, attention_store: AttentionStore, res: int, from_where: List[str], is_cross: bool, select: int):
    """Variant of `aggregate_attention` that averages only the second matching
    attention map (index 1) instead of pooling across all of them."""
    kind = 'cross' if is_cross else 'self'
    num_pixels = res ** 2
    stored = attention_store.get_average_attention()

    maps = []
    for location in from_where:
        for item in stored[f"{location}_{kind}"]:
            if item.shape[1] == num_pixels:
                maps.append(item.reshape(len(prompts), -1, res, res, item.shape[-1])[select])

    chosen = maps[1]
    return chosen.sum(0) / chosen.shape[0]
|
| 529 |
+
|
| 530 |
+
|
| 531 |
+
def show_cross_attention(tokenizer, prompts, attention_store: AttentionStore, res: int, from_where: List[str], select: int = 0):
    """Render one per-token cross-attention heatmap for the selected prompt.

    :param tokenizer: tokenizer providing ``encode``/``decode``
    :param prompts: list of prompts used during sampling
    :param attention_store: populated AttentionStore
    :param res: spatial resolution of the maps to aggregate
    :param from_where: UNet sections to pool over ("down"/"mid"/"up")
    :param select: prompt index to visualize
    """
    tokens = tokenizer.encode(prompts[select])
    decoder = tokenizer.decode
    # BUG FIX: this module's aggregate_attention takes `prompts` as its first
    # argument; the original call omitted it and raised a TypeError.
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, True, select)
    images = []
    for i in range(len(tokens)):
        image = attention_maps[:, :, i]
        image = 255 * image / image.max()
        image = image.unsqueeze(-1).expand(*image.shape, 3)
        image = image.numpy().astype(np.uint8)
        image = np.array(Image.fromarray(image).resize((256, 256)))
        # Caption each heatmap with the token it attends for.
        image = text_under_image(image, decoder(int(tokens[i])))
        images.append(image)
    view_images(np.stack(images, axis=0))
|
| 545 |
+
|
| 546 |
+
|
| 547 |
+
def show_self_attention_comp(attention_store: AttentionStore, res: int, from_where: List[str],
                             max_com=10, select: int = 0, prompts=("",)):
    """Visualize the leading SVD components of the aggregated self-attention map.

    BUG FIX: this module's aggregate_attention requires `prompts` as its first
    argument; the original call omitted it and raised a TypeError. `prompts` is
    added as a trailing keyword (defaulting to a single empty prompt) so
    existing calls stay valid — NOTE(review): confirm the default matches how
    callers batch cond/uncond prompts.

    :param max_com: number of principal components to render
    :param select: prompt index to visualize
    """
    attention_maps = aggregate_attention(prompts, attention_store, res, from_where, False, select).numpy().reshape((res ** 2, res ** 2))
    # Row-center, then take the top right singular vectors as spatial components.
    u, s, vh = np.linalg.svd(attention_maps - np.mean(attention_maps, axis=1, keepdims=True))
    images = []
    for i in range(max_com):
        image = vh[i].reshape(res, res)
        image = image - image.min()
        image = 255 * image / image.max()
        image = np.repeat(np.expand_dims(image, axis=2), 3, axis=2).astype(np.uint8)
        image = Image.fromarray(image).resize((256, 256))
        image = np.array(image)
        images.append(image)
    view_images(np.concatenate(images, axis=1))
|
| 561 |
+
|
| 562 |
+
def text_under_image(image: np.ndarray, text: str, text_color: Tuple[int, int, int] = (0, 0, 0)):
    """Append a white caption strip (20% of the height) below `image` and draw `text` on it."""
    h, w, c = image.shape
    pad = int(h * .2)
    out = np.full((h + pad, w, c), 255, dtype=np.uint8)
    out[:h] = image
    font = cv2.FONT_HERSHEY_SIMPLEX
    (tw, th), _ = cv2.getTextSize(text, font, 1, 2)
    cv2.putText(out, text, ((w - tw) // 2, h + pad - th // 2), font, 1, text_color, 2)
    return out
|
| 573 |
+
|
| 574 |
+
|
| 575 |
+
def view_images(images, num_rows=1, offset_ratio=0.02):
    """Tile `images` into a `num_rows`-row grid and show it via `display` (notebook).

    Args:
        images: list of HxWx3 uint8 arrays, a single HxWx3 array, or a 4-D
            (N, H, W, 3) array. All images must share the same shape.
        num_rows: number of grid rows.
        offset_ratio: gap between tiles, as a fraction of image height.
    """
    if type(images) is list:
        # Pad with blanks up to the next multiple of num_rows.
        # BUGFIX: previously `len(images) % num_rows` (the remainder itself) was
        # used, which under-pads and silently drops trailing images for
        # num_rows >= 3; we need the complement of the remainder.
        num_empty = (num_rows - len(images) % num_rows) % num_rows
    elif images.ndim == 4:
        num_empty = (num_rows - images.shape[0] % num_rows) % num_rows
    else:
        images = [images]
        num_empty = 0

    empty_images = np.ones(images[0].shape, dtype=np.uint8) * 255
    images = [image.astype(np.uint8) for image in images] + [empty_images] * num_empty
    num_items = len(images)

    h, w, c = images[0].shape
    offset = int(h * offset_ratio)
    num_cols = num_items // num_rows
    # White canvas holding the full grid including inter-tile gaps.
    image_ = np.ones((h * num_rows + offset * (num_rows - 1),
                      w * num_cols + offset * (num_cols - 1), 3), dtype=np.uint8) * 255
    for i in range(num_rows):
        for j in range(num_cols):
            image_[i * (h + offset): i * (h + offset) + h:, j * (w + offset): j * (w + offset) + w] = images[
                i * num_cols + j]

    pil_img = Image.fromarray(image_)
    # NOTE(review): `display` is assumed to come from IPython (notebook context);
    # it is not imported in this file — confirm.
    display(pil_img)
| 600 |
+
|
| 601 |
+
def self_cross_attn(self_attn, cross_attn):
    """Weight self-attention by a cross-attention map and average over targets.

    Args:
        self_attn: tensor of shape [res, res, res*res].
        cross_attn: tensor squeezing to shape [res, res].

    Returns:
        Tensor of shape [1, 1, res, res]: the mean over the last axis of
        `self_attn` after scaling it by the flattened cross-attention weights.
    """
    cross = cross_attn.squeeze()
    res = self_attn.shape[0]
    assert res == cross.shape[-1]
    # Flatten [res, res] -> [res*res] so it broadcasts over the last axis.
    weights = cross.reshape([res * res])
    weighted = weights * self_attn
    fused = weighted.mean(-1).unsqueeze(0).unsqueeze(0)
    return fused
|
_utils/config.yaml
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
attn_dist_mode: v0
|
| 2 |
+
attn_positional_bias: rope
|
| 3 |
+
attn_positional_bias_n_spatial: 16
|
| 4 |
+
causal_norm: quiet_softmax
|
| 5 |
+
coord_dim: 2
|
| 6 |
+
d_model: 320
|
| 7 |
+
dropout: 0.0
|
| 8 |
+
feat_dim: 7
|
| 9 |
+
feat_embed_per_dim: 8
|
| 10 |
+
nhead: 4
|
| 11 |
+
num_decoder_layers: 6
|
| 12 |
+
num_encoder_layers: 6
|
| 13 |
+
pos_embed_per_dim: 32
|
| 14 |
+
spatial_pos_cutoff: 256
|
| 15 |
+
window: 4
|
_utils/example_config.yaml
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
batch_size: 1
|
| 2 |
+
crop_size:
|
| 3 |
+
- 256
|
| 4 |
+
- 256
|
| 5 |
+
detection_folders:
|
| 6 |
+
- TRA
|
| 7 |
+
dropout: 0.01
|
| 8 |
+
example_images: False # Slow
|
| 9 |
+
input_train:
|
| 10 |
+
- data/ctc/Fluo-N2DL-HeLa/01
|
| 11 |
+
input_val:
|
| 12 |
+
- data/ctc/Fluo-N2DL-HeLa/02
|
| 13 |
+
max_tokens: 2048
|
| 14 |
+
name: example
|
| 15 |
+
ndim: 2
|
| 16 |
+
num_decoder_layers: 5
|
| 17 |
+
num_encoder_layers: 5
|
| 18 |
+
outdir: runs
|
| 19 |
+
distributed: False
|
| 20 |
+
window: 4
|
_utils/load_models.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from config import RunConfig
|
| 2 |
+
import torch
|
| 3 |
+
from diffusers.pipelines.stable_diffusion import StableDiffusionPipeline
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
|
| 6 |
+
def load_stable_diffusion_model(config: RunConfig):
    """Load the Stable Diffusion pipeline selected by `config` and move it to CPU.

    Chooses SD 2.1-base when `config.sd_2_1` is truthy, otherwise SD v1.4.
    """
    device = torch.device('cpu')
    if config.sd_2_1:
        repo_id = "stabilityai/stable-diffusion-2-1-base"
    else:
        repo_id = "CompVis/stable-diffusion-v1-4"
    # stable = StableCountingPipeline.from_pretrained(stable_diffusion_version).to(device)
    pipeline = StableDiffusionPipeline.from_pretrained(repo_id)
    return pipeline.to(device)
| 16 |
+
|
_utils/load_track_data.py
ADDED
|
@@ -0,0 +1,118 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
from glob import glob
|
| 3 |
+
from pathlib import Path
|
| 4 |
+
from natsort import natsorted
|
| 5 |
+
from PIL import Image
|
| 6 |
+
import numpy as np
|
| 7 |
+
import tifffile
|
| 8 |
+
import skimage.io as io
|
| 9 |
+
import torchvision.transforms as T
|
| 10 |
+
import cv2
|
| 11 |
+
from tqdm import tqdm
|
| 12 |
+
from models.tra_post_model.trackastra.utils import normalize_01, normalize
|
| 13 |
+
IMG_SIZE = 512
|
| 14 |
+
|
| 15 |
+
def _load_tiffs(folder: Path, dtype=None):
    """Load a sequence of tiff files from a folder into a 3D numpy array."""
    # Peek at the first file (unsorted glob order) to decide whether frames
    # are multi-channel and need grayscale conversion.
    images = glob(str(folder / "*.tif"))
    test_data = tifffile.imread(images[0])
    if len(test_data.shape) == 3:
        turn_gray = True
    else:
        turn_gray = False
    end_frame = len(images)
    if not turn_gray:
        # Already single-channel: stack frames directly (sorted lexicographically).
        x = np.stack([
            tifffile.imread(f).astype(dtype)
            for f in tqdm(
                sorted(folder.glob("*.tif"))[0 : end_frame : 1],
                leave=False,
                desc=f"Loading [0:{end_frame}]",
            )
        ])
    else:
        x = []
        for f in tqdm(
            sorted(folder.glob("*.tif"))[0 : end_frame : 1],
            leave=False,
            desc=f"Loading [0:{end_frame}]",
        ):
            img = tifffile.imread(f).astype(dtype)
            if img.ndim == 3:
                # Drop any extra channels (e.g. alpha) beyond RGB, then
                # convert to luminance with ITU-R BT.601 weights.
                if img.shape[-1] > 3:
                    img = img[..., :3]
                img = (0.299 * img[..., 0] + 0.587 * img[..., 1] + 0.114 * img[..., 2])
            x.append(img)
        x = np.stack(x)
    return x
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def load_track_images(file_dir):
    """Load a tif image sequence for segmentation/tracking/stable-diffusion use.

    Recursively collects .tif files under `file_dir` (skipping macOS
    `__MACOSX` folders), natural-sorts them, and produces several parallel
    stacks:

    Returns:
        imgs: normalized (ImageNet mean/std) RGB tensors resized to IMG_SIZE.
        imgs_raw: frames as read by skimage, unmodified.
        images_stable: resized tensors shifted to [-0.5, 0.5] for the SD model.
        imgs_: frames normalized via `normalize` (tracking input).
        imgs_01: frames normalized to [0, 1] via `normalize_01`.
        height, width: original spatial size of the *last* frame read.
    """
    # suffix_ = [".png", ".tif", ".tiff", ".jpg"]
    def find_tif_dir(root_dir):
        """Recursively find .tif files under root_dir."""
        tif_files = []
        for dirpath, _, filenames in os.walk(root_dir):
            if '__MACOSX' in dirpath:
                continue
            for f in filenames:
                if f.lower().endswith('.tif'):
                    tif_files.append(os.path.join(dirpath, f))
        return tif_files

    tif_dir = find_tif_dir(file_dir)
    print(f"Found {len(tif_dir)} tif images in {file_dir}")
    print(f"First 5 tif images: {tif_dir[:5]}")
    assert len(tif_dir) > 0, f"No tif images found in {file_dir}"
    images = natsorted(tif_dir)
    imgs = []
    imgs_raw = []
    images_stable = []
    # load images for seg and track
    for img_path in tqdm(images, desc="Loading images"):
        img = tifffile.imread(img_path)
        img_raw = io.imread(img_path)

        if img.dtype == 'uint16':
            # 16-bit frames: min-max rescale to uint8 and replicate to 3 channels.
            img = ((img - img.min()) / (img.max() - img.min() + 1e-6) * 255).astype(np.uint8)
            img = np.stack([img] * 3, axis=-1)
            w, h = img.shape[1], img.shape[0]
        else:
            img = Image.open(img_path).convert("RGB")
            w, h = img.size

        img = T.Compose([
            T.ToTensor(),
            T.Resize((IMG_SIZE, IMG_SIZE)),
        ])(img)

        # SD input expects values centered at zero; normalized copy feeds seg.
        image_stable = img - 0.5
        img = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(img)


        imgs.append(img)
        imgs_raw.append(img_raw)
        images_stable.append(image_stable)

    # NOTE(review): h/w come from the last loop iteration — assumes all
    # frames share the same size; confirm for mixed-size inputs.
    height = h
    width = w
    imgs = np.stack(imgs, axis=0)
    imgs_raw = np.stack(imgs_raw, axis=0)
    images_stable = np.stack(images_stable, axis=0)

    # track data
    imgs_ = _load_tiffs(Path(file_dir), dtype=np.float32)
    imgs_01 = np.stack([
        normalize_01(_x) for _x in tqdm(imgs_, desc="Normalizing", leave=False)
    ])
    imgs_ = np.stack([
        normalize(_x) for _x in tqdm(imgs_, desc="Normalizing", leave=False)
    ])

    return imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width
|
| 114 |
+
|
| 115 |
+
if __name__ == "__main__":
    # Smoke test: load a sample CTC sequence and report the array shapes.
    file_dir = "data/2D+Time/DIC-C2DH-HeLa/train/DIC-C2DH-HeLa/02"
    imgs, imgs_raw, images_stable, imgs_, imgs_01, height, width = load_track_images(file_dir)
    print(imgs.shape, imgs_raw.shape, images_stable.shape, imgs_.shape, imgs_01.shape, height, width)
_utils/misc_helper.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import logging
|
| 2 |
+
import os
|
| 3 |
+
import random
|
| 4 |
+
import shutil
|
| 5 |
+
from collections.abc import Mapping
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
import numpy as np
|
| 9 |
+
import torch
|
| 10 |
+
import torch.distributed as dist
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def basicConfig(*args, **kwargs):
    """No-op replacement for logging.basicConfig; accepts and ignores all arguments."""
    return None
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# To prevent duplicate logs, we mask this baseConfig setting
|
| 18 |
+
logging.basicConfig = basicConfig
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def create_logger(name, log_file, level=logging.INFO):
    """Create a logger named `name` that writes to both `log_file` and the console.

    Both handlers share a timestamped format including filename, line number,
    and level. Returns the configured `logging.Logger`.
    """
    fmt = logging.Formatter(
        "[%(asctime)s][%(filename)15s][line:%(lineno)4d][%(levelname)8s] %(message)s"
    )
    logger = logging.getLogger(name)
    logger.setLevel(level)
    # File handler first, then stream handler — same order as handlers fire.
    for handler in (logging.FileHandler(log_file), logging.StreamHandler()):
        handler.setFormatter(fmt)
        logger.addHandler(handler)
    return logger
|
| 34 |
+
|
| 35 |
+
def get_current_time():
    """Return the current local time formatted as 'YYYYMMDD_HHMMSS'."""
    return datetime.now().strftime("%Y%m%d_%H%M%S")
|
_utils/seg_eval.py
ADDED
|
@@ -0,0 +1,61 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def iou_torch(inst1, inst2):
    """Intersection-over-union of two binary masks; NaN when both are empty."""
    intersection = torch.logical_and(inst1, inst2).sum().float()
    union = torch.logical_or(inst1, inst2).sum().float()
    if union != 0:
        return intersection / union
    return torch.tensor(float('nan'))
|
| 10 |
+
|
| 11 |
+
def get_instances_torch(mask):
    """Return a boolean mask per non-background (non-zero) label in `mask`."""
    return [mask == label for label in torch.unique(mask) if label != 0]
|
| 15 |
+
|
| 16 |
+
def compute_instance_miou(pred_mask, gt_mask):
    """Mean best-match IoU over ground-truth instances.

    Args:
        pred_mask, gt_mask: integer-labeled [H, W] tensors (0 = background).

    Returns:
        NaN when the ground truth has no instances; otherwise the mean over
        GT instances of the best IoU against any predicted instance.
    """
    pred_instances = get_instances_torch(pred_mask)
    gt_instances = get_instances_torch(gt_mask)

    # No GT instances -> undefined score.
    if not gt_instances:
        return torch.tensor(float('nan'))

    best_scores = []
    for gt in gt_instances:
        best = torch.tensor(0.0).to(pred_mask.device)
        for pred in pred_instances:
            score = iou_torch(pred, gt)
            if score > best:
                best = score
        best_scores.append(best)

    return torch.nanmean(torch.stack(best_scores))
|
| 34 |
+
|
| 35 |
+
from torch import Tensor
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Mean Dice coefficient over a batch, or for a single mask.

    When `reduce_batch_first` is True the input must be 3-D and the Dice
    score is computed over the whole batch at once; otherwise it is computed
    per-sample over the last two dims and averaged. Empty masks that match
    empty targets score 1 (the `sets_sum == 0` substitution).
    """
    assert input.size() == target.size()
    assert input.dim() == 3 or not reduce_batch_first

    dims = (-1, -2) if input.dim() == 2 or not reduce_batch_first else (-1, -2, -3)

    overlap = 2 * (input * target).sum(dim=dims)
    total = input.sum(dim=dims) + target.sum(dim=dims)
    # Both masks empty -> define Dice as 1 rather than 0/0.
    total = torch.where(total == 0, overlap, total)

    scores = (overlap + epsilon) / (total + epsilon)
    return scores.mean()
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
def multiclass_dice_coeff(input: Tensor, target: Tensor, reduce_batch_first: bool = False, epsilon: float = 1e-6):
    """Average Dice over classes: folds the class dim into the batch dim."""
    flat_input = input.flatten(0, 1)
    flat_target = target.flatten(0, 1)
    return dice_coeff(flat_input, flat_target, reduce_batch_first, epsilon)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
def dice_loss(input: Tensor, target: Tensor, multiclass: bool = False):
    """Dice loss in [0, 1] (objective to minimize): 1 - Dice coefficient."""
    score_fn = multiclass_dice_coeff if multiclass else dice_coeff
    return 1 - score_fn(input, target, reduce_batch_first=True)
|
_utils/track_args.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import configargparse
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def parse_train_args():
    """Build and parse the training CLI via configargparse.

    Arguments may also come from a YAML config file (`-c`, defaults to
    `_utils/example_config.yaml`). Unknown arguments are tolerated
    (`parse_known_args`) and silently discarded.

    Returns:
        argparse.Namespace with the parsed options.
    """
    parser = configargparse.ArgumentParser(
        formatter_class=configargparse.ArgumentDefaultsHelpFormatter,
        config_file_parser_class=configargparse.YAMLConfigFileParser,
        allow_abbrev=False,
    )
    parser.add_argument(
        "-c",
        "--config",
        default="_utils/example_config.yaml",
        is_config_file=True,
        help="config file path",
    )
    # --- run / environment ---
    parser.add_argument("--device", type=str, choices=["cuda", "cpu"], default="cuda")
    parser.add_argument("-o", "--outdir", type=str, default="runs")
    parser.add_argument("--name", type=str, help="Name to append to timestamp")
    parser.add_argument("--timestamp", type=bool, default=True)
    parser.add_argument(
        "-m",
        "--model",
        type=str,
        default="",
        help="load this model at start (e.g. to continue training)",
    )
    # --- model architecture ---
    parser.add_argument(
        "--ndim", type=int, default=2, help="number of spatial dimensions"
    )
    parser.add_argument("-d", "--d_model", type=int, default=256)
    parser.add_argument("-w", "--window", type=int, default=10)
    parser.add_argument("--epochs", type=int, default=100)
    parser.add_argument("--warmup_epochs", type=int, default=10)
    parser.add_argument(
        "--detection_folders",
        type=str,
        nargs="+",
        default=["TRA"],
        help=(
            "Subfolders to search for detections. Defaults to `TRA`, which corresponds"
            " to using only the GT."
        ),
    )
    parser.add_argument("--downscale_temporal", type=int, default=1)
    parser.add_argument("--downscale_spatial", type=int, default=1)
    parser.add_argument("--spatial_pos_cutoff", type=int, default=256)
    parser.add_argument("--from_subfolder", action="store_true")
    # parser.add_argument("--train_samples", type=int, default=50000)
    parser.add_argument("--num_encoder_layers", type=int, default=6)
    parser.add_argument("--num_decoder_layers", type=int, default=6)
    parser.add_argument("--pos_embed_per_dim", type=int, default=32)
    parser.add_argument("--feat_embed_per_dim", type=int, default=8)
    parser.add_argument("--dropout", type=float, default=0.00)
    # --- data loading / optimization ---
    parser.add_argument("--num_workers", type=int, default=4)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--max_tokens", type=int, default=None)
    parser.add_argument("--delta_cutoff", type=int, default=2)
    parser.add_argument("--lr", type=float, default=1e-4)
    parser.add_argument(
        "--attn_positional_bias",
        type=str,
        choices=["rope", "bias", "none"],
        default="rope",
    )
    parser.add_argument("--attn_positional_bias_n_spatial", type=int, default=16)
    parser.add_argument("--attn_dist_mode", default="v0")
    parser.add_argument("--mixedp", type=bool, default=True)
    parser.add_argument("--dry", action="store_true")
    parser.add_argument("--profile", action="store_true")
    parser.add_argument(
        "--features",
        type=str,
        choices=[
            "none",
            "regionprops",
            "regionprops2",
            "patch",
            "patch_regionprops",
            "wrfeat",
        ],
        default="wrfeat",
    )
    parser.add_argument(
        "--causal_norm",
        type=str,
        choices=["none", "linear", "softmax", "quiet_softmax"],
        default="quiet_softmax",
    )
    parser.add_argument("--div_upweight", type=float, default=2)

    # --- augmentation / evaluation ---
    parser.add_argument("--augment", type=int, default=3)
    parser.add_argument("--tracking_frequency", type=int, default=-1)

    parser.add_argument("--sanity_dist", action="store_true")
    parser.add_argument("--preallocate", type=bool, default=False)
    parser.add_argument("--only_prechecks", action="store_true")
    parser.add_argument(
        "--compress", type=bool, default=True, help="compress dataset"
    )


    # --- logging / reproducibility ---
    parser.add_argument("--seed", type=int, default=None)
    parser.add_argument(
        "--logger",
        type=str,
        default="tensorboard",
        choices=["tensorboard", "wandb", "none"],
    )
    parser.add_argument("--wandb_project", type=str, default="trackastra")
    parser.add_argument(
        "--crop_size",
        type=int,
        # required=True,
        nargs="+",
        default=None,
        help="random crop size for augmentation",
    )
    parser.add_argument(
        "--weight_by_ndivs",
        type=bool,
        default=True,
        help="Oversample windows that contain divisions",
    )
    parser.add_argument(
        "--weight_by_dataset",
        type=bool,
        default=False,
        help=(
            "Inversely weight datasets by number of samples (to counter dataset size"
            " imbalance)"
        ),
    )

    # Unknown args are deliberately tolerated (see disabled check below).
    args, unknown_args = parser.parse_known_args()

    # # Hack to allow for --input_test
    # allowed_unknown = ["input_test"]
    # if not set(a.split("=")[0].strip("-") for a in unknown_args).issubset(
    #     set(allowed_unknown)
    # ):
    #     raise ValueError(f"Unknown args: {unknown_args}")

    # pprint(vars(args))

    # for backward compatibility
    # if args.attn_positional_bias == "True":
    #     args.attn_positional_bias = "bias"
    # elif args.attn_positional_bias == "False":
    #     args.attn_positional_bias = False

    # if args.train_samples == 0:
    #     raise NotImplementedError(
    #         "--train_samples must be > 0, full dataset pass not supported."
    #     )

    return args
|
app.py
ADDED
|
@@ -0,0 +1,1638 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from gradio_bbox_annotator import BBoxAnnotator
|
| 3 |
+
from PIL import Image
|
| 4 |
+
import numpy as np
|
| 5 |
+
import torch
|
| 6 |
+
import os
|
| 7 |
+
import shutil
|
| 8 |
+
import time
|
| 9 |
+
import json
|
| 10 |
+
import uuid
|
| 11 |
+
from pathlib import Path
|
| 12 |
+
import tempfile
|
| 13 |
+
import zipfile
|
| 14 |
+
from skimage import measure
|
| 15 |
+
from matplotlib import cm
|
| 16 |
+
from glob import glob
|
| 17 |
+
from natsort import natsorted
|
| 18 |
+
from huggingface_hub import HfApi, upload_file
|
| 19 |
+
# import spaces
|
| 20 |
+
|
| 21 |
+
# ===== 导入三个推理模块 =====
|
| 22 |
+
from inference_seg import load_model as load_seg_model, run as run_seg
|
| 23 |
+
from inference_count import load_model as load_count_model, run as run_count
|
| 24 |
+
from inference_track import load_model as load_track_model, run as run_track
|
| 25 |
+
|
| 26 |
+
HF_TOKEN = os.getenv("HF_TOKEN")
|
| 27 |
+
DATASET_REPO = "phoebe777777/celltool_feedback"
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# ===== Clear stale Gradio upload cache from previous runs =====
print("===== clearing cache =====")
# cache_path = os.path.expanduser("~/.cache/")
cache_path = os.path.expanduser("~/.cache/huggingface/gradio")
if os.path.exists(cache_path):
    try:
        shutil.rmtree(cache_path)
        print("✅ Deleted ~/.cache/huggingface/gradio")
    except OSError as e:
        # Best-effort cleanup: a failure here must never block startup,
        # but swallowing it silently (the old bare `except: pass`) hid
        # permission/filesystem problems — log the reason instead.
        print(f"⚠️ Could not clear cache: {e}")
|
| 41 |
+
|
| 42 |
+
# ===== Global model handles =====
# Populated once at startup by load_all_models(); handlers read them directly.
# Segmentation model and the device it runs on.
SEG_MODEL = None
SEG_DEVICE = torch.device("cpu")

# Counting model and its device.
COUNT_MODEL = None
COUNT_DEVICE = torch.device("cpu")

# Tracking model and its device.
TRACK_MODEL = None
TRACK_DEVICE = torch.device("cpu")
|
| 51 |
+
|
| 52 |
+
def load_all_models():
    """Load the segmentation, counting and tracking models into module globals."""
    global SEG_MODEL, SEG_DEVICE
    global COUNT_MODEL, COUNT_DEVICE
    global TRACK_MODEL, TRACK_DEVICE

    banner = "=" * 60

    print("\n" + banner)
    print("📦 Loading Segmentation Model")
    print(banner)
    SEG_MODEL, SEG_DEVICE = load_seg_model(use_box=False)

    print("\n" + banner)
    print("📦 Loading Counting Model")
    print(banner)
    COUNT_MODEL, COUNT_DEVICE = load_count_model(use_box=False)

    print("\n" + banner)
    print("📦 Loading Tracking Model")
    print(banner)
    TRACK_MODEL, TRACK_DEVICE = load_track_model(use_box=False)

    print("\n" + banner)
    print("✅ All Models Loaded Successfully")
    print(banner)
|
| 76 |
+
|
| 77 |
+
# Load every model eagerly at import time so the UI is ready on first request.
load_all_models()

# ===== Local fallback storage for user feedback =====
DATASET_DIR = Path("solver_cache")
DATASET_DIR.mkdir(parents=True, exist_ok=True)
|
| 82 |
+
|
| 83 |
+
def save_feedback_to_hf(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
    """Upload one feedback record to the Hugging Face feedback dataset.

    Falls back to local JSON storage (``save_feedback``) when no HF_TOKEN is
    configured or when the upload fails for any reason.

    Parameters
    ----------
    query_id : str
        Identifier of the inference request the feedback refers to.
    feedback_type : str
        Category of feedback (e.g. thumbs up/down).
    feedback_text : str, optional
        Free-form user comment.
    img_path : str, optional
        Path of the image the feedback is about.
    bboxes : list, optional
        Bounding boxes associated with the feedback; stringified for JSON.
    """
    if not HF_TOKEN:
        print("⚠️ No HF_TOKEN found, using local storage")
        save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
        return

    feedback_data = {
        "query_id": query_id,
        "feedback_type": feedback_type,
        "feedback_text": feedback_text,
        "image_path": img_path,
        "bboxes": str(bboxes),  # stringified so arbitrary box formats stay JSON-safe
        "datetime": time.strftime("%Y-%m-%d %H:%M:%S"),
        "timestamp": time.time(),
    }

    # Timestamped name keeps each record at a unique path in the dataset repo.
    filename = f"feedback_{query_id}_{int(time.time())}.json"

    try:
        api = HfApi()

        with open(filename, 'w', encoding='utf-8') as f:
            json.dump(feedback_data, f, indent=2, ensure_ascii=False)

        # BUG FIX: upload under the actual file name — the original used a
        # literal placeholder path, so every upload targeted the same repo
        # entry and overwrote the previous record.
        api.upload_file(
            path_or_fileobj=filename,
            path_in_repo=f"data/{filename}",
            repo_id=DATASET_REPO,
            repo_type="dataset",
            token=HF_TOKEN
        )

        print(f"✅ Feedback saved to HF Dataset: {DATASET_REPO}")

    except Exception as e:
        print(f"⚠️ Failed to save to HF Dataset: {e}")
        # Fall back to local storage so the feedback is never lost.
        save_feedback(query_id, feedback_type, feedback_text, img_path, bboxes)
    finally:
        # Always remove the local temp file, even when the upload raised
        # (the original leaked the file on failure).
        if os.path.exists(filename):
            os.remove(filename)
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
def save_feedback(query_id, feedback_type, feedback_text=None, img_path=None, bboxes=None):
    """Append one feedback record to <DATASET_DIR>/<query_id>/feedback.json.

    The file always holds a JSON list; a legacy single-object file is
    wrapped into a list before appending.
    """
    entry = {
        "query_id": query_id,
        "feedback_type": feedback_type,
        "feedback_text": feedback_text,
        "image": img_path,
        "bboxes": bboxes,
        "datetime": time.strftime("%Y%m%d_%H%M%S")
    }

    feedback_file = DATASET_DIR / query_id / "feedback.json"
    feedback_file.parent.mkdir(parents=True, exist_ok=True)

    records = []
    if feedback_file.exists():
        with feedback_file.open("r") as f:
            loaded = json.load(f)
        records = loaded if isinstance(loaded, list) else [loaded]
    records.append(entry)

    with feedback_file.open("w") as f:
        json.dump(records, f, indent=4, ensure_ascii=False)
|
| 156 |
+
|
| 157 |
+
# ===== Helpers =====
def parse_first_bbox(bboxes):
    """Return (xmin, ymin, xmax, ymax) for the first box, or None.

    Accepts either dict boxes ({"x", "y", "width", "height"}) or
    sequence boxes ([xmin, ymin, xmax, ymax, ...]).
    """
    if not bboxes:
        return None

    first = bboxes[0]
    if isinstance(first, dict):
        xmin = float(first.get("x", 0))
        ymin = float(first.get("y", 0))
        xmax = xmin + float(first.get("width", 0))
        ymax = ymin + float(first.get("height", 0))
        return xmin, ymin, xmax, ymax
    if isinstance(first, (list, tuple)) and len(first) >= 4:
        return tuple(float(v) for v in first[:4])
    return None
|
| 170 |
+
|
| 171 |
+
def parse_bboxes(bboxes):
    """Convert annotator boxes into [[xmin, ymin, xmax, ymax], ...].

    Returns None when the input is empty/None; unrecognized entries are
    silently skipped (an all-invalid input yields an empty list).
    """
    if not bboxes:
        return None

    parsed = []
    for box in bboxes:
        if isinstance(box, dict):
            x0 = float(box.get("x", 0))
            y0 = float(box.get("y", 0))
            x1 = x0 + float(box.get("width", 0))
            y1 = y0 + float(box.get("height", 0))
            parsed.append([x0, y0, x1, y1])
        elif isinstance(box, (list, tuple)) and len(box) >= 4:
            parsed.append([float(v) for v in box[:4]])

    return parsed
|
| 186 |
+
|
| 187 |
+
def colorize_mask(mask: np.ndarray, num_colors: int = 512) -> np.ndarray:
    """Map integer instance labels to RGB colors via a fixed HSV palette.

    Label 0 (background) maps to black; labels wrap modulo ``num_colors``,
    so ids that differ by a multiple of ``num_colors`` share a color.
    """
    def _hsv_to_rgb_u8(h, s, v):
        # Classic sector-based HSV→RGB, truncated to 0..255 ints.
        sector = int(h * 6.0)
        frac = h * 6.0 - sector
        sector %= 6
        p = v * (1 - s)
        q = v * (1 - frac * s)
        t = v * (1 - (1 - frac) * s)
        rgb_by_sector = [
            (v, t, p), (q, v, p), (p, v, t),
            (p, q, v), (t, p, v), (v, p, q),
        ]
        r, g, b = rgb_by_sector[sector]
        return int(r * 255), int(g * 255), int(b * 255)

    lut = [(0, 0, 0)]  # background stays black
    lut.extend(
        _hsv_to_rgb_u8((idx % num_colors) / float(num_colors), 1.0, 0.95)
        for idx in range(1, num_colors)
    )

    lut_arr = np.array(lut, dtype=np.uint8)
    return lut_arr[mask % num_colors]
|
| 212 |
+
|
| 213 |
+
# ===== Segmentation =====
# @spaces.GPU
def segment_with_choice(use_box_choice, annot_value):
    """Run instance segmentation and build a per-instance color overlay.

    Parameters
    ----------
    use_box_choice : str
        "Yes" to restrict inference to the user-drawn bounding boxes.
    annot_value : sequence
        BBoxAnnotator value: (image_path, bboxes); bboxes may be absent.

    Returns
    -------
    tuple
        (overlay PIL.Image, path to saved uint16 mask TIF), or (None, None)
        on any failure. When no instance is found, returns a solid red
        image and None for the mask path.
    """
    if annot_value is None or len(annot_value) < 1:
        print("❌ No annotation input")
        return None, None

    img_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []

    print(f"🖼️ Image path: {img_path}")
    box_array = None
    if use_box_choice == "Yes" and bboxes:
        # All drawn boxes are forwarded to the model (not only the first).
        box = parse_bboxes(bboxes)
        if box:
            box_array = box
            print(f"📦 Using bounding boxes: {box_array}")

    # Run the segmentation model.
    try:
        mask = run_seg(SEG_MODEL, img_path, box=box_array, device=SEG_DEVICE)
        print("📏 mask shape:", mask.shape, "dtype:", mask.dtype, "unique:", np.unique(mask))
    except Exception as e:
        print(f"❌ Inference failed: {str(e)}")
        return None, None

    # Persist the raw instance mask as a 16-bit TIF for download.
    temp_mask_file = tempfile.NamedTemporaryFile(delete=False, suffix=".tif")
    mask_img = Image.fromarray(mask.astype(np.uint16))
    mask_img.save(temp_mask_file.name)
    print(f"💾 Original mask saved to: {temp_mask_file.name}")

    # Load the original image.
    try:
        img = Image.open(img_path)
        print("📷 Image mode:", img.mode, "size:", img.size)
    except Exception as e:
        print(f"❌ Failed to open image: {e}")
        return None, None

    try:
        # mask.shape is (H, W); PIL resize expects (W, H), hence [::-1].
        img_rgb = img.convert("RGB").resize(mask.shape[::-1], resample=Image.BILINEAR)
        img_np = np.array(img_rgb, dtype=np.float32)
        if img_np.max() > 1.5:
            img_np = img_np / 255.0
    except Exception as e:
        print(f"❌ Error in image conversion/resizing: {e}")
        return None, None

    mask_np = np.array(mask)
    inst_mask = mask_np.astype(np.int32)
    unique_ids = np.unique(inst_mask)
    num_instances = len(unique_ids[unique_ids != 0])  # 0 is background
    print(f"✅ Instance IDs found: {unique_ids}, Total instances: {num_instances}")

    if num_instances == 0:
        print("⚠️ No instance found, returning dummy red image")
        return Image.new("RGB", mask.shape[::-1], (255, 0, 0)), None

    # ==== Color overlay (one distinct color per instance) ====
    overlay = img_np.copy()
    alpha = 0.5

    for inst_id in np.unique(inst_mask):
        if inst_id == 0:
            continue
        binary_mask = (inst_mask == inst_id).astype(np.uint8)
        color = get_well_spaced_color(inst_id)
        overlay[binary_mask == 1] = (1 - alpha) * overlay[binary_mask == 1] + alpha * color

        # Draw this instance's contour in yellow.
        contours = measure.find_contours(binary_mask, 0.5)
        for contour in contours:
            contour = contour.astype(np.int32)
            # Clamp contour coordinates inside the image bounds.
            valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
            valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
            overlay[valid_y, valid_x] = [1.0, 1.0, 0.0]  # yellow contour

    overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)

    return Image.fromarray(overlay), temp_mask_file.name
|
| 304 |
+
|
| 305 |
+
# ===== Counting =====
# @spaces.GPU
def count_cells_handler(use_box_choice, annot_value):
    """Counting handler - supports bounding boxes, returns a density overlay.

    BUG FIXES vs. the original:
    - Every code path now returns a 3-tuple (overlay_image, density_file,
      status_text); error paths previously returned 2-tuples, misaligning
      the Gradio outputs on failure.
    - The final bbox check referenced a local ``box`` that was only bound
      inside a nested branch, raising NameError when "Yes" was selected
      with no boxes drawn; it now checks ``box_array``.

    Parameters
    ----------
    use_box_choice : str
        "Yes" to restrict counting to the user-drawn bounding boxes.
    annot_value : sequence
        BBoxAnnotator value: (image_path, bboxes); bboxes may be absent.

    Returns
    -------
    tuple
        (PIL.Image overlay or None, .npy density-map path or None, status text)
    """
    if annot_value is None or len(annot_value) < 1:
        return None, None, "⚠️ Please provide an image."

    image_path = annot_value[0]
    bboxes = annot_value[1] if len(annot_value) > 1 else []

    print(f"🖼️ Image path: {image_path}")
    box_array = None
    if use_box_choice == "Yes" and bboxes:
        parsed = parse_bboxes(bboxes)
        if parsed:
            box_array = parsed
            print(f"📦 Using bounding boxes: {box_array}")

    try:
        print(f"🔢 Counting - Image: {image_path}")

        result = run_count(
            COUNT_MODEL,
            image_path,
            box=box_array,
            device=COUNT_DEVICE,
            visualize=True
        )

        if 'error' in result:
            return None, None, f"❌ Counting failed: {result['error']}"

        count = result['count']
        density_map = result['density_map']

        # Persist the raw density map so the user can download it.
        temp_density_file = tempfile.NamedTemporaryFile(delete=False, suffix=".npy")
        np.save(temp_density_file.name, density_map)
        print(f"💾 Density map saved to {temp_density_file.name}")

        try:
            img = Image.open(image_path)
            print("📷 Image mode:", img.mode, "size:", img.size)
        except Exception as e:
            print(f"❌ Failed to open image: {e}")
            return None, None, f"❌ Failed to open image: {e}"

        try:
            # density_map is (H, W); PIL resize expects (W, H).
            img_rgb = img.convert("RGB").resize(density_map.shape[::-1], resample=Image.BILINEAR)
            img_np = np.array(img_rgb, dtype=np.float32)
            # Min/max normalize to [0, 1]; the original's follow-up
            # ``max() > 1.5`` rescale was unreachable after this and has
            # been removed.
            img_np = (img_np - img_np.min()) / (img_np.max() - img_np.min() + 1e-8)
        except Exception as e:
            print(f"❌ Error in image conversion/resizing: {e}")
            return None, None, f"❌ Error in image conversion/resizing: {e}"

        # Normalize density map to [0, 1] (guard against an all-zero map).
        density_normalized = density_map.copy()
        if density_normalized.max() > 0:
            density_normalized = (density_normalized - density_normalized.min()) / (density_normalized.max() - density_normalized.min())

        # Apply colormap to the normalized density.
        cmap = cm.get_cmap("jet")
        alpha = 0.3
        density_colored = cmap(density_normalized)[:, :, :3]  # RGB only, ignore alpha

        # Blend only where density is significant (> 1% of max).
        overlay = img_np.copy()
        threshold = 0.01
        significant_mask = density_normalized > threshold
        overlay[significant_mask] = (1 - alpha) * overlay[significant_mask] + alpha * density_colored[significant_mask]

        overlay = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)

        result_text = f"✅ Detected {round(count)} objects"
        if box_array:
            result_text += f"\n📦 Using bounding box: {box_array}"

        print(f"✅ Counting done - Count: {count:.1f}")

        return Image.fromarray(overlay), temp_density_file.name, result_text

    except Exception as e:
        print(f"❌ Counting error: {e}")
        import traceback
        traceback.print_exc()
        return None, None, f"❌ Counting failed: {str(e)}"
| 410 |
+
|
| 411 |
+
# ===== Tracking Functionality =====
def find_tif_dir(root_dir):
    """Walk root_dir and return the first directory containing .tif files.

    Directories under a __MACOSX path are ignored. Returns None when no
    directory qualifies.
    """
    for current, _dirs, files in os.walk(root_dir):
        if '__MACOSX' in current:
            continue
        has_tif = any(name.lower().endswith('.tif') for name in files)
        if has_tif:
            return current
    return None
|
| 420 |
+
|
| 421 |
+
def is_valid_tiff(filepath):
    """Return True when PIL can open and verify the file, else False."""
    try:
        with Image.open(filepath) as handle:
            handle.verify()
        return True
    except Exception:
        # Any failure (corrupt file, unsupported format, I/O error) means
        # the file is not a usable image.
        return False
|
| 429 |
+
|
| 430 |
+
def find_valid_tif_dir(root_dir):
    """Walk root_dir and return the first directory with verifiable TIFFs.

    Skips __MACOSX paths and AppleDouble ('._*') files; each candidate is
    checked with is_valid_tiff. Returns None when nothing qualifies.
    """
    for current, _subdirs, files in os.walk(root_dir):
        if '__MACOSX' in current:
            continue

        candidates = [
            os.path.join(current, name)
            for name in files
            if name.lower().endswith(('.tif', '.tiff')) and not name.startswith('._')
        ]
        if not candidates:
            continue

        verified = [path for path in candidates if is_valid_tiff(path)]
        if verified:
            print(f"✅ Found {len(verified)} valid TIFF files in: {current}")
            return current

    return None
|
| 452 |
+
|
| 453 |
+
def create_ctc_results_zip(output_dir):
    """Bundle every tracking result file from output_dir into a ZIP.

    Parameters
    ----------
    output_dir : str
        Directory containing tracking results (res_track.txt, etc.)

    Returns
    -------
    zip_path : str
        Path to created ZIP file (inside a fresh temp directory, named with
        a timestamp). A README.txt summary is appended to the archive.
    """
    staging_dir = tempfile.mkdtemp()
    zip_name = f"tracking_results_{time.strftime('%Y%m%d_%H%M%S')}.zip"
    zip_path = os.path.join(staging_dir, zip_name)

    print(f"📦 Creating results ZIP: {zip_path}")

    with zipfile.ZipFile(zip_path, 'w', zipfile.ZIP_DEFLATED) as archive:
        # Recursively archive everything under output_dir, keeping
        # paths relative to it.
        for root, dirs, names in os.walk(output_dir):
            for name in names:
                full_path = os.path.join(root, name)
                rel_path = os.path.relpath(full_path, output_dir)
                archive.write(full_path, rel_path)
                print(f" 📄 Added: {rel_path}")

        # Describe the archive contents for the downloader.
        readme_content = f"""Tracking Results Summary
========================

Generated: {time.strftime('%Y-%m-%d %H:%M:%S')}

Files:
------
- res_track.txt: CTC format tracking data
Format: track_id start_frame end_frame parent_id

- Segmentation masks

For more information on CTC format:
http://celltrackingchallenge.net/
"""
        archive.writestr("README.txt", readme_content)

    print(f"✅ ZIP created: {zip_path} ({os.path.getsize(zip_path) / 1024:.1f} KB)")
    return zip_path
|
| 504 |
+
|
| 505 |
+
# Smarter color assignment: neighboring track IDs get strongly contrasting hues.
def get_well_spaced_color(track_id, num_colors=256):
    """Return an RGB triple (np.ndarray, values in [0, 1]) for a track ID.

    Hues step by the golden ratio so consecutive IDs land far apart on the
    color wheel; saturation/value are kept high for a vivid overlay.
    """
    import colorsys

    GOLDEN_RATIO = 0.618033988749895
    hue = (track_id * GOLDEN_RATIO) % 1.0
    return np.array(colorsys.hsv_to_rgb(hue, 0.9, 0.95))
|
| 516 |
+
|
| 517 |
+
|
| 518 |
+
def extract_first_frame(tif_dir):
    """Return the path of the first valid TIF frame in tif_dir.

    Frames are naturally sorted; AppleDouble ('._*') files and files that
    fail is_valid_tiff are skipped.

    Returns
    -------
    first_frame_path : str or None
        Path to the first TIF frame, or None when no valid frame exists.
    """
    all_frames = natsorted(
        glob(os.path.join(tif_dir, "*.tif")) + glob(os.path.join(tif_dir, "*.tiff"))
    )
    usable = [
        path for path in all_frames
        if not os.path.basename(path).startswith('._') and is_valid_tiff(path)
    ]
    return usable[0] if usable else None
|
| 535 |
+
|
| 536 |
+
def create_tracking_visualization(tif_dir, output_dir, valid_tif_files):
    """
    Create an animated GIF showing tracked objects with consistent colors.

    Each track ID is tinted with get_well_spaced_color, so an object keeps
    the same color across frames; contours are drawn in yellow. Falls back
    to a static PNG of the first processed frame, or to the first raw TIF,
    when GIF creation fails.

    Parameters:
    -----------
    tif_dir : str
        Directory containing input TIF frames
    output_dir : str
        Directory containing tracking results (masks)
    valid_tif_files : list
        List of valid TIF file paths

    Returns:
    --------
    video_path : str
        Path to generated visualization (GIF or first frame)
    """
    import numpy as np
    from matplotlib import colormaps  # NOTE: unused since the cmap path was replaced by get_well_spaced_color
    from skimage import measure
    import tifffile

    # Look for tracking mask files in output directory
    # Common CTC formats: man_track*.tif, mask*.tif, or numbered masks
    mask_files = natsorted(glob(os.path.join(output_dir, "mask*.tif")) +
                           glob(os.path.join(output_dir, "man_track*.tif")) +
                           glob(os.path.join(output_dir, "*.tif")))

    if not mask_files:
        print("⚠️ No mask files found in output directory")
        # Return first frame as fallback
        return valid_tif_files[0]

    print(f"📊 Found {len(mask_files)} mask files")

    frames = []
    alpha = 0.3  # Transparency for overlay

    # Process each frame (pairs image i with mask i; extras are ignored)
    num_frames = min(len(valid_tif_files), len(mask_files))
    for i in range(num_frames):
        try:
            # Load original image using tifffile (handles ZSTD compression)
            try:
                img_np = tifffile.imread(valid_tif_files[i])

                # Normalize to [0, 1] range based on actual data type and values
                if img_np.dtype == np.uint8:
                    img_np = img_np.astype(np.float32) / 255.0
                elif img_np.dtype == np.uint16:
                    # Normalize uint16 to [0, 1] using actual min/max
                    img_min, img_max = img_np.min(), img_np.max()
                    if img_max > img_min:
                        img_np = (img_np.astype(np.float32) - img_min) / (img_max - img_min)
                    else:
                        # Constant image: fall back to full-range scaling
                        img_np = img_np.astype(np.float32) / 65535.0
                else:
                    # For float or other types, normalize based on actual range
                    img_np = img_np.astype(np.float32)
                    img_min, img_max = img_np.min(), img_np.max()
                    if img_max > img_min:
                        img_np = (img_np - img_min) / (img_max - img_min)
                    else:
                        img_np = np.clip(img_np, 0, 1)

                # Convert to RGB if grayscale
                if img_np.ndim == 2:
                    img_np = np.stack([img_np]*3, axis=-1)
                img_np = img_np.astype(np.float32)
                if img_np.max() > 1.5:
                    img_np = img_np / 255.0
            except Exception as e:
                print(f"⚠️ Error loading image frame {i}: {e}")
                # Fallback to PIL
                img = Image.open(valid_tif_files[i]).convert("RGB")
                img_np = np.array(img, dtype=np.float32) / 255.0

            # Load tracking mask using tifffile (handles ZSTD compression)
            try:
                mask = tifffile.imread(mask_files[i])
            except Exception as e:
                print(f"⚠️ Error loading mask frame {i}: {e}")
                # Fallback to PIL
                mask = np.array(Image.open(mask_files[i]))

            # Resize mask to match image if needed (nearest-neighbor keeps labels intact)
            if mask.shape[:2] != img_np.shape[:2]:
                from scipy.ndimage import zoom
                zoom_factors = [img_np.shape[0] / mask.shape[0], img_np.shape[1] / mask.shape[1]]
                mask = zoom(mask, zoom_factors, order=0).astype(mask.dtype)

            # Create overlay
            overlay = img_np.copy()

            # Get unique track IDs (excluding background 0)
            track_ids = np.unique(mask)
            track_ids = track_ids[track_ids != 0]

            # Color each tracked object
            for track_id in track_ids:
                # Create binary mask for this track
                binary_mask = (mask == track_id)

                # Get consistent color for this track ID
                color = get_well_spaced_color(int(track_id))

                # Blend color onto image
                overlay[binary_mask] = (1 - alpha) * overlay[binary_mask] + alpha * color

                # Draw contours (optional, adds yellow boundaries)
                try:
                    contours = measure.find_contours(binary_mask.astype(np.uint8), 0.5)
                    for contour in contours:
                        contour = contour.astype(np.int32)
                        valid_y = np.clip(contour[:, 0], 0, overlay.shape[0] - 1)
                        valid_x = np.clip(contour[:, 1], 0, overlay.shape[1] - 1)
                        overlay[valid_y, valid_x] = [1.0, 1.0, 0.0]  # Yellow contour
                except:
                    pass  # Skip contours if they fail

            # Convert to uint8
            overlay_uint8 = np.clip(overlay * 255.0, 0, 255).astype(np.uint8)
            frames.append(Image.fromarray(overlay_uint8))

            if i % 10 == 0 or i == num_frames - 1:
                print(f" 📸 Processed frame {i+1}/{num_frames}")

        except Exception as e:
            print(f"⚠️ Error processing frame {i}: {e}")
            import traceback
            traceback.print_exc()
            continue

    if not frames:
        print("⚠️ No frames were processed successfully")
        return valid_tif_files[0]

    # Save as animated GIF
    try:
        temp_gif = tempfile.NamedTemporaryFile(delete=False, suffix=".gif")
        frames[0].save(
            temp_gif.name,
            save_all=True,
            append_images=frames[1:],
            duration=200,  # 200ms per frame = 5fps
            loop=0
        )
        temp_gif.close()  # Close the file handle
        print(f"✅ Created tracking visualization GIF: {temp_gif.name}")
        print(f" Size: {os.path.getsize(temp_gif.name)} bytes, Frames: {len(frames)}")
        return temp_gif.name
    except Exception as e:
        print(f"⚠️ Failed to create GIF: {e}")
        import traceback
        traceback.print_exc()
        # Return first frame as static image fallback
        try:
            temp_img = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
            frames[0].save(temp_img.name)
            temp_img.close()
            return temp_img.name
        except:
            return valid_tif_files[0]
|
| 709 |
+
|
| 710 |
+
# @spaces.GPU
def track_video_handler(use_box_choice, first_frame_annot, zip_file_obj):
    """Tracking handler for a ZIP of TIF frames, with optional first-frame boxes.

    BUG FIX vs. the original: the final bbox summary referenced a local
    ``box`` that was only bound inside a nested branch. When "Yes" was
    selected without annotating the first frame, that raised NameError,
    which the outer ``except`` swallowed — so a *successful* run was
    reported as "Tracking failed". The check now uses ``box_array``.

    Parameters
    ----------
    use_box_choice : str
        "Yes" or "No" — whether to use bounding boxes.
    first_frame_annot : tuple or None
        (image_path, bboxes) from BBoxAnnotator; only used if the user
        annotated the first frame.
    zip_file_obj : File
        Uploaded ZIP file containing the TIF sequence.

    Returns
    -------
    tuple
        (results_zip_path, status_text, download_visibility_update, tracking_video)
        with None placeholders on failure.
    """
    if zip_file_obj is None:
        return None, "⚠️ 请上传包含视频帧的压缩包 (.zip)", None, None

    temp_dir = None
    output_temp_dir = None

    try:
        # Parse first-frame bounding boxes if provided.
        box_array = None
        if use_box_choice == "Yes" and first_frame_annot is not None:
            if isinstance(first_frame_annot, (list, tuple)) and len(first_frame_annot) > 1:
                bboxes = first_frame_annot[1]
                if bboxes:
                    parsed = parse_bboxes(bboxes)
                    if parsed:
                        box_array = parsed
                        print(f"📦 Using bounding boxes: {box_array}")

        # Extract the uploaded ZIP, skipping macOS metadata entries.
        temp_dir = tempfile.mkdtemp()
        print(f"\n📦 Extracting to temporary directory: {temp_dir}")

        with zipfile.ZipFile(zip_file_obj.name, 'r') as zip_ref:
            extracted_count = 0
            skipped_count = 0

            for member in zip_ref.namelist():
                basename = os.path.basename(member)

                # Skip macOS resource forks, Finder metadata and directories.
                if ('__MACOSX' in member or
                        basename.startswith('._') or
                        basename.startswith('.DS_Store') or
                        member.endswith('/')):
                    skipped_count += 1
                    continue

                try:
                    zip_ref.extract(member, temp_dir)
                    extracted_count += 1
                    if basename.lower().endswith(('.tif', '.tiff')):
                        print(f"📄 Extracted TIFF: {basename}")
                except Exception as e:
                    print(f"⚠️ Failed to extract {member}: {e}")

        print(f"\n📊 Extracted: {extracted_count} files, Skipped: {skipped_count} files")

        # Find the directory that actually holds verifiable TIFF frames.
        tif_dir = find_valid_tif_dir(temp_dir)
        if tif_dir is None:
            return None, "❌ Did not find valid TIF directory", None, None

        # Re-validate and naturally sort the frame files.
        tif_files = natsorted(glob(os.path.join(tif_dir, "*.tif")) +
                              glob(os.path.join(tif_dir, "*.tiff")))
        valid_tif_files = [f for f in tif_files
                           if not os.path.basename(f).startswith('._') and is_valid_tiff(f)]

        if len(valid_tif_files) == 0:
            return None, "❌ Did not find valid TIF files", None, None

        print(f"📈 Using {len(valid_tif_files)} TIF files")

        # Kept for the visualization fallback below.
        first_frame_path = valid_tif_files[0]

        # Temporary output directory for the CTC-format results.
        output_temp_dir = tempfile.mkdtemp()
        print(f"💾 CTC-format results will be saved to: {output_temp_dir}")

        # Run tracking with the optional bounding boxes.
        result = run_track(
            TRACK_MODEL,
            video_dir=tif_dir,
            box=box_array,
            device=TRACK_DEVICE,
            output_dir=output_temp_dir
        )

        if 'error' in result:
            return None, f"❌ Tracking failed: {result['error']}", None, None

        # Build the animated visualization of the tracked objects.
        print("\n🎬 Creating tracking visualization...")
        try:
            tracking_video = create_tracking_visualization(
                tif_dir,
                output_temp_dir,
                valid_tif_files
            )
        except Exception as e:
            print(f"⚠️ Failed to create visualization: {e}")
            import traceback
            traceback.print_exc()
            # Fallback to the first frame if visualization fails.
            try:
                tracking_video = Image.open(first_frame_path)
            except Exception:
                tracking_video = None

        # Bundle the CTC results for download.
        try:
            results_zip = create_ctc_results_zip(output_temp_dir)
        except Exception as e:
            print(f"⚠️ Failed to create ZIP: {e}")
            results_zip = None

        bbox_info = ""
        if box_array:
            bbox_info = f"\n🔲 Using bounding box: [{box_array[0][0]}, {box_array[0][1]}, {box_array[0][2]}, {box_array[0][3]}]"

        result_text = f"""✅ Tracking completed!

🖼️ Processed frames: {len(valid_tif_files)}{bbox_info}

📥 Click the button below to download CTC-format results
The results include:
- res_track.txt (CTC-format tracking data)
- Other tracking-related files
- README.txt (Results description)
"""

        # FIXED: was `... and box`, which could raise NameError (see docstring).
        if use_box_choice == "Yes" and box_array:
            result_text += f"\n📦 Using bounding box: {box_array}"

        print(f"\n✅ Tracking completed")

        # Clean up the input temp directory (output dir is kept for download).
        if temp_dir:
            try:
                shutil.rmtree(temp_dir)
                print(f"🗑️ Cleared input temp directory")
            except OSError:
                pass

        return results_zip, result_text, gr.update(visible=True), tracking_video

    except zipfile.BadZipFile:
        return None, "❌ Not a valid ZIP file", None, None
    except Exception as e:
        import traceback
        traceback.print_exc()

        # Best-effort cleanup of both temp directories on error.
        for d in [temp_dir, output_temp_dir]:
            if d:
                try:
                    shutil.rmtree(d)
                except OSError:
                    pass
        return None, f"❌ Tracking failed: {str(e)}", None, None
|
| 880 |
+
|
| 881 |
+
|
| 882 |
+
|
| 883 |
+
# ===== 示例图像 =====
|
| 884 |
+
example_images_seg = [f for f in glob("example_imgs/seg/*")]
|
| 885 |
+
# ["example_imgs/seg/003_img.png", "example_imgs/seg/1977_Well_F-5_Field_1.png"]
|
| 886 |
+
example_images_cnt = [f for f in glob("example_imgs/cnt/*")]
|
| 887 |
+
example_tracking_zips = [f for f in glob("example_imgs/tra/*.zip")]
|
| 888 |
+
|
| 889 |
+
# ===== Gradio UI =====
|
| 890 |
+
with gr.Blocks(
|
| 891 |
+
title="Microscopy Analysis Suite",
|
| 892 |
+
theme=gr.themes.Soft(),
|
| 893 |
+
css="""
|
| 894 |
+
.tabs button {
|
| 895 |
+
font-size: 18px !important;
|
| 896 |
+
font-weight: 600 !important;
|
| 897 |
+
padding: 12px 20px !important;
|
| 898 |
+
}
|
| 899 |
+
.uniform-height {
|
| 900 |
+
height: 500px !important;
|
| 901 |
+
display: flex !important;
|
| 902 |
+
align-items: center !important;
|
| 903 |
+
justify-content: center !important;
|
| 904 |
+
}
|
| 905 |
+
|
| 906 |
+
.uniform-height img,
|
| 907 |
+
.uniform-height canvas {
|
| 908 |
+
max-height: 500px !important;
|
| 909 |
+
object-fit: contain !important;
|
| 910 |
+
}
|
| 911 |
+
|
| 912 |
+
/* 强制密度图容器和图片高度 */
|
| 913 |
+
#density_map_output {
|
| 914 |
+
height: 500px !important;
|
| 915 |
+
}
|
| 916 |
+
|
| 917 |
+
#density_map_output .image-container {
|
| 918 |
+
height: 500px !important;
|
| 919 |
+
}
|
| 920 |
+
|
| 921 |
+
#density_map_output img {
|
| 922 |
+
height: 480px !important;
|
| 923 |
+
width: auto !important;
|
| 924 |
+
max-width: 90% !important;
|
| 925 |
+
object-fit: contain !important;
|
| 926 |
+
}
|
| 927 |
+
"""
|
| 928 |
+
) as demo:
|
| 929 |
+
gr.Markdown(
|
| 930 |
+
"""
|
| 931 |
+
# 🔬 Microscopy Image Analysis Suite
|
| 932 |
+
|
| 933 |
+
Supporting three key tasks:
|
| 934 |
+
- 🎨 **Segmentation**: Instance segmentation of microscopic objects
|
| 935 |
+
- 🔢 **Counting**: Counting microscopic objects based on density maps
|
| 936 |
+
- 🎬 **Tracking**: Tracking microscopic objects in video sequences
|
| 937 |
+
"""
|
| 938 |
+
)
|
| 939 |
+
|
| 940 |
+
# 全局状态
|
| 941 |
+
current_query_id = gr.State(str(uuid.uuid4()))
|
| 942 |
+
user_uploaded_examples = gr.State(example_images_seg.copy()) # 初始化时包含原始示例
|
| 943 |
+
|
| 944 |
+
with gr.Tabs():
|
| 945 |
+
# ===== Tab 1: Segmentation =====
|
| 946 |
+
with gr.Tab("🎨 Segmentation"):
|
| 947 |
+
gr.Markdown("## Instance Segmentation of Microscopic Objects")
|
| 948 |
+
gr.Markdown(
|
| 949 |
+
"""
|
| 950 |
+
**Instructions:**
|
| 951 |
+
1. Upload an image or select an example image (supports various formats: .png, .jpg, .tif)
|
| 952 |
+
2. (Optional) Specify a target object with a bounding box and select "Yes", or click "Run Segmentation" directly
|
| 953 |
+
3. Click "Run Segmentation"
|
| 954 |
+
4. View the segmentation results, download the original predicted mask (.tif format); if needed, click "Clear Selection" to choose a new image
|
| 955 |
+
|
| 956 |
+
🤘 Rate and submit feedback to help us improve the model!
|
| 957 |
+
"""
|
| 958 |
+
)
|
| 959 |
+
|
| 960 |
+
with gr.Row():
|
| 961 |
+
with gr.Column(scale=1):
|
| 962 |
+
annotator = BBoxAnnotator(
|
| 963 |
+
label="🖼️ Upload Image (Optional: Provide a Bounding Box)",
|
| 964 |
+
categories=["cell"],
|
| 965 |
+
)
|
| 966 |
+
|
| 967 |
+
# Example Images Gallery
|
| 968 |
+
example_gallery = gr.Gallery(
|
| 969 |
+
label="📁 Example Image Gallery",
|
| 970 |
+
columns=len(example_images_seg),
|
| 971 |
+
rows=1,
|
| 972 |
+
height=120,
|
| 973 |
+
object_fit="cover",
|
| 974 |
+
show_download_button=False
|
| 975 |
+
)
|
| 976 |
+
|
| 977 |
+
|
| 978 |
+
with gr.Row():
|
| 979 |
+
use_box_radio = gr.Radio(
|
| 980 |
+
choices=["Yes", "No"],
|
| 981 |
+
value="No",
|
| 982 |
+
label="🔲 Specify Bounding Box?"
|
| 983 |
+
)
|
| 984 |
+
with gr.Row():
|
| 985 |
+
run_seg_btn = gr.Button("▶️ Run Segmentation", variant="primary", size="lg")
|
| 986 |
+
clear_btn = gr.Button("🔄 Clear Selection", variant="secondary")
|
| 987 |
+
|
| 988 |
+
# Upload Example Image
|
| 989 |
+
image_uploader = gr.Image(
|
| 990 |
+
label="➕ Upload New Example Image to Gallery",
|
| 991 |
+
type="filepath"
|
| 992 |
+
)
|
| 993 |
+
|
| 994 |
+
|
| 995 |
+
with gr.Column(scale=2):
|
| 996 |
+
seg_output = gr.Image(
|
| 997 |
+
type="pil",
|
| 998 |
+
label="📸 Segmentation Result",
|
| 999 |
+
elem_classes="uniform-height"
|
| 1000 |
+
)
|
| 1001 |
+
|
| 1002 |
+
# Download Original Prediction
|
| 1003 |
+
download_mask_btn = gr.File(
|
| 1004 |
+
label="📥 Download Original Prediction (.tif format)",
|
| 1005 |
+
visible=True,
|
| 1006 |
+
height=40,
|
| 1007 |
+
)
|
| 1008 |
+
|
| 1009 |
+
# Satisfaction Rating
|
| 1010 |
+
score_slider = gr.Slider(
|
| 1011 |
+
minimum=1,
|
| 1012 |
+
maximum=5,
|
| 1013 |
+
step=1,
|
| 1014 |
+
value=5,
|
| 1015 |
+
label="🌟 Satisfaction Rating (1-5)"
|
| 1016 |
+
)
|
| 1017 |
+
|
| 1018 |
+
# Feedback Textbox
|
| 1019 |
+
feedback_box = gr.Textbox(
|
| 1020 |
+
placeholder="Please enter your feedback...",
|
| 1021 |
+
lines=2,
|
| 1022 |
+
label="💬 Feedback"
|
| 1023 |
+
)
|
| 1024 |
+
|
| 1025 |
+
# Submit Button
|
| 1026 |
+
submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary")
|
| 1027 |
+
|
| 1028 |
+
feedback_status = gr.Textbox(
|
| 1029 |
+
label="✅ Submission Status",
|
| 1030 |
+
lines=1,
|
| 1031 |
+
visible=False
|
| 1032 |
+
)
|
| 1033 |
+
|
| 1034 |
+
# 绑定事件: 运行分割
|
| 1035 |
+
run_seg_btn.click(
|
| 1036 |
+
fn=segment_with_choice,
|
| 1037 |
+
inputs=[use_box_radio, annotator],
|
| 1038 |
+
outputs=[seg_output, download_mask_btn]
|
| 1039 |
+
)
|
| 1040 |
+
|
| 1041 |
+
# 清空按钮事件
|
| 1042 |
+
clear_btn.click(
|
| 1043 |
+
fn=lambda: None,
|
| 1044 |
+
inputs=None,
|
| 1045 |
+
outputs=annotator
|
| 1046 |
+
)
|
| 1047 |
+
|
| 1048 |
+
# 初始化Gallery显示
|
| 1049 |
+
demo.load(
|
| 1050 |
+
fn=lambda: example_images_seg.copy(),
|
| 1051 |
+
outputs=example_gallery
|
| 1052 |
+
)
|
| 1053 |
+
|
| 1054 |
+
# 绑定事件: 上传示例图片
|
| 1055 |
+
def add_to_gallery(img_path, current_imgs):
|
| 1056 |
+
if not img_path:
|
| 1057 |
+
return current_imgs
|
| 1058 |
+
try:
|
| 1059 |
+
if img_path not in current_imgs:
|
| 1060 |
+
current_imgs.append(img_path)
|
| 1061 |
+
return current_imgs
|
| 1062 |
+
except:
|
| 1063 |
+
return current_imgs
|
| 1064 |
+
|
| 1065 |
+
image_uploader.change(
|
| 1066 |
+
fn=add_to_gallery,
|
| 1067 |
+
inputs=[image_uploader, user_uploaded_examples],
|
| 1068 |
+
outputs=user_uploaded_examples
|
| 1069 |
+
).then(
|
| 1070 |
+
fn=lambda imgs: imgs,
|
| 1071 |
+
inputs=user_uploaded_examples,
|
| 1072 |
+
outputs=example_gallery
|
| 1073 |
+
)
|
| 1074 |
+
|
| 1075 |
+
# 绑定事件: 点击Gallery加载
|
| 1076 |
+
def load_from_gallery(evt: gr.SelectData, all_imgs):
|
| 1077 |
+
if evt.index is not None and evt.index < len(all_imgs):
|
| 1078 |
+
return all_imgs[evt.index]
|
| 1079 |
+
return None
|
| 1080 |
+
|
| 1081 |
+
example_gallery.select(
|
| 1082 |
+
fn=load_from_gallery,
|
| 1083 |
+
inputs=user_uploaded_examples,
|
| 1084 |
+
outputs=annotator
|
| 1085 |
+
)
|
| 1086 |
+
|
| 1087 |
+
# 绑定事件: 提交反馈
|
| 1088 |
+
def submit_user_feedback(query_id, score, comment, annot_val):
|
| 1089 |
+
try:
|
| 1090 |
+
img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
|
| 1091 |
+
bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else []
|
| 1092 |
+
|
| 1093 |
+
# save_feedback(
|
| 1094 |
+
# query_id=query_id,
|
| 1095 |
+
# feedback_type=f"score_{int(score)}",
|
| 1096 |
+
# feedback_text=comment,
|
| 1097 |
+
# img_path=img_path,
|
| 1098 |
+
# bboxes=bboxes
|
| 1099 |
+
# )
|
| 1100 |
+
# 使用 HF 存储
|
| 1101 |
+
save_feedback_to_hf(
|
| 1102 |
+
query_id=query_id,
|
| 1103 |
+
feedback_type=f"score_{int(score)}",
|
| 1104 |
+
feedback_text=comment,
|
| 1105 |
+
img_path=img_path,
|
| 1106 |
+
bboxes=bboxes
|
| 1107 |
+
)
|
| 1108 |
+
return "✅ Feedback submitted, thank you!", gr.update(visible=True)
|
| 1109 |
+
except Exception as e:
|
| 1110 |
+
return f"❌ Submission failed: {str(e)}", gr.update(visible=True)
|
| 1111 |
+
|
| 1112 |
+
submit_feedback_btn.click(
|
| 1113 |
+
fn=submit_user_feedback,
|
| 1114 |
+
inputs=[current_query_id, score_slider, feedback_box, annotator],
|
| 1115 |
+
outputs=[feedback_status, feedback_status]
|
| 1116 |
+
)
|
| 1117 |
+
|
| 1118 |
+
# ===== Tab 2: Counting =====
|
| 1119 |
+
with gr.Tab("🔢 Counting"):
|
| 1120 |
+
gr.Markdown("## Microscopy Object Counting Analysis")
|
| 1121 |
+
gr.Markdown(
|
| 1122 |
+
"""
|
| 1123 |
+
**Usage Instructions:**
|
| 1124 |
+
1. Upload an image or select an example image (supports multiple formats: .png, .jpg, .tif)
|
| 1125 |
+
2. (Optional) Specify a target object with a bounding box and select "Yes", or click "Run Counting" directly
|
| 1126 |
+
3. Click "Run Counting"
|
| 1127 |
+
4. View the density map, download the original prediction (.npy format); if needed, click "Clear Selection" to choose a new image to run
|
| 1128 |
+
|
| 1129 |
+
🤘 Rate and submit feedback to help us improve the model!
|
| 1130 |
+
"""
|
| 1131 |
+
)
|
| 1132 |
+
|
| 1133 |
+
with gr.Row():
|
| 1134 |
+
with gr.Column(scale=1):
|
| 1135 |
+
count_annotator = BBoxAnnotator(
|
| 1136 |
+
label="🖼️ Upload Image (Optional: Provide a Bounding Box)",
|
| 1137 |
+
categories=["cell"],
|
| 1138 |
+
)
|
| 1139 |
+
|
| 1140 |
+
# Example gallery with "add" functionality
|
| 1141 |
+
with gr.Row():
|
| 1142 |
+
count_example_gallery = gr.Gallery(
|
| 1143 |
+
label="📁 Example Image Gallery",
|
| 1144 |
+
columns=len(example_images_cnt),
|
| 1145 |
+
rows=1,
|
| 1146 |
+
object_fit="cover",
|
| 1147 |
+
height=120,
|
| 1148 |
+
value=example_images_cnt.copy(), # Initialize with examples
|
| 1149 |
+
show_download_button=False
|
| 1150 |
+
)
|
| 1151 |
+
|
| 1152 |
+
|
| 1153 |
+
with gr.Row():
|
| 1154 |
+
count_use_box_radio = gr.Radio(
|
| 1155 |
+
choices=["Yes", "No"],
|
| 1156 |
+
value="No",
|
| 1157 |
+
label="🔲 Specify Bounding Box?"
|
| 1158 |
+
)
|
| 1159 |
+
|
| 1160 |
+
with gr.Row():
|
| 1161 |
+
count_btn = gr.Button("▶️ Run Counting", variant="primary", size="lg")
|
| 1162 |
+
clear_btn = gr.Button("🔄 Clear Selection", variant="secondary")
|
| 1163 |
+
|
| 1164 |
+
# Add button to upload new examples
|
| 1165 |
+
with gr.Row():
|
| 1166 |
+
count_image_uploader = gr.File(
|
| 1167 |
+
label="➕ Add Example Image to Gallery",
|
| 1168 |
+
file_types=["image"],
|
| 1169 |
+
type="filepath"
|
| 1170 |
+
)
|
| 1171 |
+
|
| 1172 |
+
|
| 1173 |
+
with gr.Column(scale=2):
|
| 1174 |
+
count_output = gr.Image(
|
| 1175 |
+
label="📸 Density Map",
|
| 1176 |
+
type="filepath",
|
| 1177 |
+
elem_id="density_map_output"
|
| 1178 |
+
|
| 1179 |
+
)
|
| 1180 |
+
count_status = gr.Textbox(
|
| 1181 |
+
label="📊 Statistics",
|
| 1182 |
+
lines=2
|
| 1183 |
+
)
|
| 1184 |
+
download_density_btn = gr.File(
|
| 1185 |
+
label="📥 Download Original Prediction (.npy format)",
|
| 1186 |
+
visible=True
|
| 1187 |
+
)
|
| 1188 |
+
|
| 1189 |
+
# Satisfaction rating
|
| 1190 |
+
score_slider = gr.Slider(
|
| 1191 |
+
minimum=1,
|
| 1192 |
+
maximum=5,
|
| 1193 |
+
step=1,
|
| 1194 |
+
value=5,
|
| 1195 |
+
label="🌟 Satisfaction Rating (1-5)"
|
| 1196 |
+
)
|
| 1197 |
+
|
| 1198 |
+
# Feedback textbox
|
| 1199 |
+
feedback_box = gr.Textbox(
|
| 1200 |
+
placeholder="Please enter your feedback...",
|
| 1201 |
+
lines=2,
|
| 1202 |
+
label="💬 Feedback"
|
| 1203 |
+
)
|
| 1204 |
+
|
| 1205 |
+
# Submit button
|
| 1206 |
+
submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary")
|
| 1207 |
+
|
| 1208 |
+
feedback_status = gr.Textbox(
|
| 1209 |
+
label="✅ Submission Status",
|
| 1210 |
+
lines=1,
|
| 1211 |
+
visible=False
|
| 1212 |
+
)
|
| 1213 |
+
|
| 1214 |
+
# State for managing gallery images
|
| 1215 |
+
count_user_examples = gr.State(example_images_cnt.copy())
|
| 1216 |
+
|
| 1217 |
+
# Function to add image to gallery
|
| 1218 |
+
def add_to_count_gallery(new_img_file, current_imgs):
|
| 1219 |
+
"""Add uploaded image to gallery"""
|
| 1220 |
+
if new_img_file is None:
|
| 1221 |
+
return current_imgs, current_imgs
|
| 1222 |
+
|
| 1223 |
+
try:
|
| 1224 |
+
# Add new image path to list
|
| 1225 |
+
if new_img_file not in current_imgs:
|
| 1226 |
+
current_imgs.append(new_img_file)
|
| 1227 |
+
print(f"✅ Added image to gallery: {new_img_file}")
|
| 1228 |
+
except Exception as e:
|
| 1229 |
+
print(f"⚠️ Failed to add image: {e}")
|
| 1230 |
+
|
| 1231 |
+
return current_imgs, current_imgs
|
| 1232 |
+
|
| 1233 |
+
# When user uploads a new image file
|
| 1234 |
+
count_image_uploader.upload(
|
| 1235 |
+
fn=add_to_count_gallery,
|
| 1236 |
+
inputs=[count_image_uploader, count_user_examples],
|
| 1237 |
+
outputs=[count_user_examples, count_example_gallery]
|
| 1238 |
+
)
|
| 1239 |
+
|
| 1240 |
+
# When user selects from gallery, load into annotator
|
| 1241 |
+
def load_from_count_gallery(evt: gr.SelectData, all_imgs):
|
| 1242 |
+
"""Load selected image from gallery into annotator"""
|
| 1243 |
+
if evt.index is not None and evt.index < len(all_imgs):
|
| 1244 |
+
selected_img = all_imgs[evt.index]
|
| 1245 |
+
print(f"📸 Loading image from gallery: {selected_img}")
|
| 1246 |
+
return selected_img
|
| 1247 |
+
return None
|
| 1248 |
+
|
| 1249 |
+
count_example_gallery.select(
|
| 1250 |
+
fn=load_from_count_gallery,
|
| 1251 |
+
inputs=count_user_examples,
|
| 1252 |
+
outputs=count_annotator
|
| 1253 |
+
)
|
| 1254 |
+
|
| 1255 |
+
# Run counting
|
| 1256 |
+
count_btn.click(
|
| 1257 |
+
fn=count_cells_handler,
|
| 1258 |
+
inputs=[count_use_box_radio, count_annotator],
|
| 1259 |
+
outputs=[count_output, download_density_btn, count_status]
|
| 1260 |
+
)
|
| 1261 |
+
|
| 1262 |
+
# 清空按钮事件
|
| 1263 |
+
clear_btn.click(
|
| 1264 |
+
fn=lambda: None,
|
| 1265 |
+
inputs=None,
|
| 1266 |
+
outputs=count_annotator
|
| 1267 |
+
)
|
| 1268 |
+
|
| 1269 |
+
# 绑定事件: 提交反馈
|
| 1270 |
+
def submit_user_feedback(query_id, score, comment, annot_val):
|
| 1271 |
+
try:
|
| 1272 |
+
img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
|
| 1273 |
+
bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else []
|
| 1274 |
+
|
| 1275 |
+
# save_feedback(
|
| 1276 |
+
# query_id=query_id,
|
| 1277 |
+
# feedback_type=f"score_{int(score)}",
|
| 1278 |
+
# feedback_text=comment,
|
| 1279 |
+
# img_path=img_path,
|
| 1280 |
+
# bboxes=bboxes
|
| 1281 |
+
# )
|
| 1282 |
+
# 使用 HF 存储
|
| 1283 |
+
save_feedback_to_hf(
|
| 1284 |
+
query_id=query_id,
|
| 1285 |
+
feedback_type=f"score_{int(score)}",
|
| 1286 |
+
feedback_text=comment,
|
| 1287 |
+
img_path=img_path,
|
| 1288 |
+
bboxes=bboxes
|
| 1289 |
+
)
|
| 1290 |
+
return "✅ Feedback submitted successfully, thank you!", gr.update(visible=True)
|
| 1291 |
+
except Exception as e:
|
| 1292 |
+
return f"❌ Submission failed: {str(e)}", gr.update(visible=True)
|
| 1293 |
+
|
| 1294 |
+
submit_feedback_btn.click(
|
| 1295 |
+
fn=submit_user_feedback,
|
| 1296 |
+
inputs=[current_query_id, score_slider, feedback_box, annotator],
|
| 1297 |
+
outputs=[feedback_status, feedback_status]
|
| 1298 |
+
)
|
| 1299 |
+
|
| 1300 |
+
# ===== Tab 3: Tracking =====
|
| 1301 |
+
with gr.Tab("🎬 Tracking"):
|
| 1302 |
+
gr.Markdown("## Microscopy Object Video Tracking - Supports ZIP Upload")
|
| 1303 |
+
gr.Markdown(
|
| 1304 |
+
"""
|
| 1305 |
+
**Instructions:**
|
| 1306 |
+
1. Upload a ZIP file or select from the example library. The ZIP should contain a sequence of TIF images named in chronological order (e.g., t000.tif, t001.tif...)
|
| 1307 |
+
2. (Optional) Specify a target object with a bounding box on the first frame and select "Yes", or click "Run Tracking" directly
|
| 1308 |
+
3. Click "Run Tracking"
|
| 1309 |
+
4. Download the CTC format results; if needed, click "Clear Selection" to choose a new ZIP file to run
|
| 1310 |
+
|
| 1311 |
+
🤘 Rate and submit feedback to help us improve the model!
|
| 1312 |
+
|
| 1313 |
+
"""
|
| 1314 |
+
)
|
| 1315 |
+
|
| 1316 |
+
with gr.Row():
|
| 1317 |
+
with gr.Column(scale=1):
|
| 1318 |
+
track_zip_upload = gr.File(
|
| 1319 |
+
label="📦 Upload Image Sequence in ZIP File",
|
| 1320 |
+
file_types=[".zip"]
|
| 1321 |
+
)
|
| 1322 |
+
|
| 1323 |
+
# First frame annotation for bounding box
|
| 1324 |
+
track_first_frame_annotator = BBoxAnnotator(
|
| 1325 |
+
label="🖼️ (Optional) First Frame Bounding Box Annotation",
|
| 1326 |
+
categories=["cell"],
|
| 1327 |
+
visible=False, # Hidden initially
|
| 1328 |
+
)
|
| 1329 |
+
|
| 1330 |
+
# Example ZIP gallery
|
| 1331 |
+
track_example_gallery = gr.Gallery(
|
| 1332 |
+
label="📁 Example Video Gallery (Click to Select)",
|
| 1333 |
+
columns=10,
|
| 1334 |
+
rows=1,
|
| 1335 |
+
height=120,
|
| 1336 |
+
object_fit="contain",
|
| 1337 |
+
show_download_button=False
|
| 1338 |
+
)
|
| 1339 |
+
|
| 1340 |
+
with gr.Row():
|
| 1341 |
+
track_use_box_radio = gr.Radio(
|
| 1342 |
+
choices=["Yes", "No"],
|
| 1343 |
+
value="No",
|
| 1344 |
+
label="🔲 Specify Bounding Box?"
|
| 1345 |
+
)
|
| 1346 |
+
|
| 1347 |
+
with gr.Row():
|
| 1348 |
+
track_btn = gr.Button("▶️ Run Tracking", variant="primary", size="lg")
|
| 1349 |
+
clear_btn = gr.Button("🔄 Clear Selection", variant="secondary")
|
| 1350 |
+
|
| 1351 |
+
# Add to gallery button
|
| 1352 |
+
track_gallery_upload = gr.File(
|
| 1353 |
+
label="➕ Add ZIP to Example Gallery",
|
| 1354 |
+
file_types=[".zip"],
|
| 1355 |
+
type="filepath"
|
| 1356 |
+
)
|
| 1357 |
+
|
| 1358 |
+
with gr.Column(scale=2):
|
| 1359 |
+
track_first_frame_preview = gr.Image(
|
| 1360 |
+
label="📸 Tracking Visualization",
|
| 1361 |
+
type="filepath",
|
| 1362 |
+
# height=400,
|
| 1363 |
+
elem_classes="uniform-height",
|
| 1364 |
+
interactive=False
|
| 1365 |
+
)
|
| 1366 |
+
|
| 1367 |
+
track_output = gr.Textbox(
|
| 1368 |
+
label="📊 Tracking Information",
|
| 1369 |
+
lines=8,
|
| 1370 |
+
interactive=False
|
| 1371 |
+
)
|
| 1372 |
+
|
| 1373 |
+
track_download = gr.File(
|
| 1374 |
+
label="📥 Download Tracking Results (CTC Format)",
|
| 1375 |
+
visible=False
|
| 1376 |
+
)
|
| 1377 |
+
|
| 1378 |
+
# Satisfaction rating
|
| 1379 |
+
score_slider = gr.Slider(
|
| 1380 |
+
minimum=1,
|
| 1381 |
+
maximum=5,
|
| 1382 |
+
step=1,
|
| 1383 |
+
value=5,
|
| 1384 |
+
label="🌟 Satisfaction Rating (1-5)"
|
| 1385 |
+
)
|
| 1386 |
+
|
| 1387 |
+
# Feedback textbox
|
| 1388 |
+
feedback_box = gr.Textbox(
|
| 1389 |
+
placeholder="Please enter your feedback...",
|
| 1390 |
+
lines=2,
|
| 1391 |
+
label="💬 Feedback"
|
| 1392 |
+
)
|
| 1393 |
+
|
| 1394 |
+
# Submit button
|
| 1395 |
+
submit_feedback_btn = gr.Button("💾 Submit Feedback", variant="secondary")
|
| 1396 |
+
|
| 1397 |
+
feedback_status = gr.Textbox(
|
| 1398 |
+
label="✅ Submission Status",
|
| 1399 |
+
lines=1,
|
| 1400 |
+
visible=False
|
| 1401 |
+
)
|
| 1402 |
+
|
| 1403 |
+
# State for tracking examples
|
| 1404 |
+
track_user_examples = gr.State(example_tracking_zips.copy())
|
| 1405 |
+
|
| 1406 |
+
# Function to get preview image from ZIP
|
| 1407 |
+
def get_zip_preview(zip_path):
|
| 1408 |
+
"""Extract first frame from ZIP for gallery preview"""
|
| 1409 |
+
try:
|
| 1410 |
+
temp_dir = tempfile.mkdtemp()
|
| 1411 |
+
with zipfile.ZipFile(zip_path, 'r') as zip_ref:
|
| 1412 |
+
for member in zip_ref.namelist():
|
| 1413 |
+
basename = os.path.basename(member)
|
| 1414 |
+
if ('__MACOSX' not in member and
|
| 1415 |
+
not basename.startswith('._') and
|
| 1416 |
+
basename.lower().endswith(('.tif', '.tiff', '.png', '.jpg'))):
|
| 1417 |
+
zip_ref.extract(member, temp_dir)
|
| 1418 |
+
extracted_path = os.path.join(temp_dir, member)
|
| 1419 |
+
|
| 1420 |
+
# Load and normalize for preview
|
| 1421 |
+
import tifffile
|
| 1422 |
+
import numpy as np
|
| 1423 |
+
|
| 1424 |
+
img_np = tifffile.imread(extracted_path)
|
| 1425 |
+
if img_np.dtype == np.uint16:
|
| 1426 |
+
img_min, img_max = img_np.min(), img_np.max()
|
| 1427 |
+
if img_max > img_min:
|
| 1428 |
+
img_np = ((img_np.astype(np.float32) - img_min) / (img_max - img_min) * 255).astype(np.uint8)
|
| 1429 |
+
|
| 1430 |
+
if img_np.ndim == 2:
|
| 1431 |
+
img_np = np.stack([img_np]*3, axis=-1)
|
| 1432 |
+
|
| 1433 |
+
# Save preview
|
| 1434 |
+
preview_path = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
|
| 1435 |
+
Image.fromarray(img_np).save(preview_path.name)
|
| 1436 |
+
return preview_path.name
|
| 1437 |
+
except:
|
| 1438 |
+
pass
|
| 1439 |
+
return None
|
| 1440 |
+
|
| 1441 |
+
# Initialize gallery with previews
|
| 1442 |
+
def init_tracking_gallery():
|
| 1443 |
+
"""Create preview images for ZIP examples"""
|
| 1444 |
+
previews = []
|
| 1445 |
+
for zip_path in example_tracking_zips:
|
| 1446 |
+
if os.path.exists(zip_path):
|
| 1447 |
+
preview = get_zip_preview(zip_path)
|
| 1448 |
+
if preview:
|
| 1449 |
+
previews.append(preview)
|
| 1450 |
+
return previews
|
| 1451 |
+
|
| 1452 |
+
# Load gallery on startup
|
| 1453 |
+
demo.load(
|
| 1454 |
+
fn=init_tracking_gallery,
|
| 1455 |
+
outputs=track_example_gallery
|
| 1456 |
+
)
|
| 1457 |
+
|
| 1458 |
+
# Add ZIP to gallery
|
| 1459 |
+
def add_zip_to_gallery(zip_path, current_zips):
|
| 1460 |
+
if not zip_path:
|
| 1461 |
+
return current_zips, track_example_gallery
|
| 1462 |
+
try:
|
| 1463 |
+
if zip_path not in current_zips:
|
| 1464 |
+
current_zips.append(zip_path)
|
| 1465 |
+
print(f"✅ Added ZIP to gallery: {zip_path}")
|
| 1466 |
+
# Regenerate previews
|
| 1467 |
+
previews = []
|
| 1468 |
+
for zp in current_zips:
|
| 1469 |
+
preview = get_zip_preview(zp)
|
| 1470 |
+
if preview:
|
| 1471 |
+
previews.append(preview)
|
| 1472 |
+
return current_zips, previews
|
| 1473 |
+
except Exception as e:
|
| 1474 |
+
print(f"⚠️ Error: {e}")
|
| 1475 |
+
return current_zips, []
|
| 1476 |
+
|
| 1477 |
+
track_gallery_upload.upload(
|
| 1478 |
+
fn=add_zip_to_gallery,
|
| 1479 |
+
inputs=[track_gallery_upload, track_user_examples],
|
| 1480 |
+
outputs=[track_user_examples, track_example_gallery]
|
| 1481 |
+
)
|
| 1482 |
+
|
| 1483 |
+
# Select ZIP from gallery
|
| 1484 |
+
def load_zip_from_gallery(evt: gr.SelectData, all_zips):
|
| 1485 |
+
if evt.index is not None and evt.index < len(all_zips):
|
| 1486 |
+
selected_zip = all_zips[evt.index]
|
| 1487 |
+
print(f"📁 Selected ZIP from gallery: {selected_zip}")
|
| 1488 |
+
return selected_zip
|
| 1489 |
+
return None
|
| 1490 |
+
|
| 1491 |
+
track_example_gallery.select(
|
| 1492 |
+
fn=load_zip_from_gallery,
|
| 1493 |
+
inputs=track_user_examples,
|
| 1494 |
+
outputs=track_zip_upload
|
| 1495 |
+
)
|
| 1496 |
+
|
| 1497 |
+
# Load first frame when ZIP is uploaded
|
| 1498 |
+
def load_first_frame_for_annotation(zip_file_obj):
|
| 1499 |
+
'''Load and normalize first frame from ZIP for annotation'''
|
| 1500 |
+
if zip_file_obj is None:
|
| 1501 |
+
return None, gr.update(visible=False)
|
| 1502 |
+
|
| 1503 |
+
import tifffile
|
| 1504 |
+
import numpy as np
|
| 1505 |
+
|
| 1506 |
+
try:
|
| 1507 |
+
temp_dir = tempfile.mkdtemp()
|
| 1508 |
+
with zipfile.ZipFile(zip_file_obj.name, 'r') as zip_ref:
|
| 1509 |
+
for member in zip_ref.namelist():
|
| 1510 |
+
basename = os.path.basename(member)
|
| 1511 |
+
if ('__MACOSX' not in member and
|
| 1512 |
+
not basename.startswith('._') and
|
| 1513 |
+
basename.lower().endswith(('.tif', '.tiff'))):
|
| 1514 |
+
zip_ref.extract(member, temp_dir)
|
| 1515 |
+
|
| 1516 |
+
tif_dir = find_valid_tif_dir(temp_dir)
|
| 1517 |
+
if tif_dir:
|
| 1518 |
+
first_frame = extract_first_frame(tif_dir)
|
| 1519 |
+
if first_frame:
|
| 1520 |
+
# Load and normalize the first frame
|
| 1521 |
+
try:
|
| 1522 |
+
img_np = tifffile.imread(first_frame)
|
| 1523 |
+
|
| 1524 |
+
# Normalize to [0, 255] uint8 range for display
|
| 1525 |
+
if img_np.dtype == np.uint8:
|
| 1526 |
+
pass # Already uint8
|
| 1527 |
+
elif img_np.dtype == np.uint16:
|
| 1528 |
+
# Normalize uint16 using actual min/max
|
| 1529 |
+
img_min, img_max = img_np.min(), img_np.max()
|
| 1530 |
+
if img_max > img_min:
|
| 1531 |
+
img_np = ((img_np.astype(np.float32) - img_min) / (img_max - img_min) * 255).astype(np.uint8)
|
| 1532 |
+
else:
|
| 1533 |
+
img_np = (img_np.astype(np.float32) / 65535.0 * 255).astype(np.uint8)
|
| 1534 |
+
else:
|
| 1535 |
+
# Float or other types
|
| 1536 |
+
img_np = img_np.astype(np.float32)
|
| 1537 |
+
img_min, img_max = img_np.min(), img_np.max()
|
| 1538 |
+
if img_max > img_min:
|
| 1539 |
+
img_np = ((img_np - img_min) / (img_max - img_min) * 255).astype(np.uint8)
|
| 1540 |
+
else:
|
| 1541 |
+
img_np = np.clip(img_np * 255, 0, 255).astype(np.uint8)
|
| 1542 |
+
|
| 1543 |
+
# Convert to RGB if grayscale
|
| 1544 |
+
if img_np.ndim == 2:
|
| 1545 |
+
img_np = np.stack([img_np]*3, axis=-1)
|
| 1546 |
+
elif img_np.ndim == 3 and img_np.shape[2] > 3:
|
| 1547 |
+
img_np = img_np[:, :, :3]
|
| 1548 |
+
|
| 1549 |
+
# Save normalized image to temp file
|
| 1550 |
+
temp_img = tempfile.NamedTemporaryFile(delete=False, suffix=".png")
|
| 1551 |
+
Image.fromarray(img_np).save(temp_img.name)
|
| 1552 |
+
|
| 1553 |
+
print(f"✅ Loaded and normalized first frame: {first_frame}")
|
| 1554 |
+
print(f" Original dtype: {tifffile.imread(first_frame).dtype}")
|
| 1555 |
+
print(f" Normalized to uint8 RGB for annotation")
|
| 1556 |
+
|
| 1557 |
+
return temp_img.name, gr.update(visible=True)
|
| 1558 |
+
except Exception as e:
|
| 1559 |
+
print(f"⚠️ Error normalizing first frame: {e}")
|
| 1560 |
+
import traceback
|
| 1561 |
+
traceback.print_exc()
|
| 1562 |
+
# Fallback to original file
|
| 1563 |
+
return first_frame, gr.update(visible=True)
|
| 1564 |
+
except Exception as e:
|
| 1565 |
+
print(f"⚠️ Error loading first frame: {e}")
|
| 1566 |
+
import traceback
|
| 1567 |
+
traceback.print_exc()
|
| 1568 |
+
return None, gr.update(visible=False)
|
| 1569 |
+
|
| 1570 |
+
# Load first frame when ZIP is uploaded
|
| 1571 |
+
track_zip_upload.change(
|
| 1572 |
+
fn=load_first_frame_for_annotation,
|
| 1573 |
+
inputs=track_zip_upload,
|
| 1574 |
+
outputs=[track_first_frame_annotator, track_first_frame_annotator]
|
| 1575 |
+
)
|
| 1576 |
+
|
| 1577 |
+
# Run tracking
|
| 1578 |
+
track_btn.click(
|
| 1579 |
+
fn=track_video_handler,
|
| 1580 |
+
inputs=[track_use_box_radio, track_first_frame_annotator, track_zip_upload],
|
| 1581 |
+
outputs=[track_download, track_output, track_download, track_first_frame_preview]
|
| 1582 |
+
)
|
| 1583 |
+
|
| 1584 |
+
# 清空按钮事件
|
| 1585 |
+
clear_btn.click(
|
| 1586 |
+
fn=lambda: None,
|
| 1587 |
+
inputs=None,
|
| 1588 |
+
outputs=track_first_frame_annotator
|
| 1589 |
+
)
|
| 1590 |
+
|
| 1591 |
+
# 绑定事件: 提交反馈
|
| 1592 |
+
def submit_user_feedback(query_id, score, comment, annot_val):
|
| 1593 |
+
try:
|
| 1594 |
+
img_path = annot_val[0] if annot_val and len(annot_val) > 0 else None
|
| 1595 |
+
bboxes = annot_val[1] if annot_val and len(annot_val) > 1 else []
|
| 1596 |
+
|
| 1597 |
+
# save_feedback(
|
| 1598 |
+
# query_id=query_id,
|
| 1599 |
+
# feedback_type=f"score_{int(score)}",
|
| 1600 |
+
# feedback_text=comment,
|
| 1601 |
+
# img_path=img_path,
|
| 1602 |
+
# bboxes=bboxes
|
| 1603 |
+
# )
|
| 1604 |
+
# 使用 HF 存储
|
| 1605 |
+
save_feedback_to_hf(
|
| 1606 |
+
query_id=query_id,
|
| 1607 |
+
feedback_type=f"score_{int(score)}",
|
| 1608 |
+
feedback_text=comment,
|
| 1609 |
+
img_path=img_path,
|
| 1610 |
+
bboxes=bboxes
|
| 1611 |
+
)
|
| 1612 |
+
return "✅ Feedback submitted successfully, thank you!", gr.update(visible=True)
|
| 1613 |
+
except Exception as e:
|
| 1614 |
+
return f"❌ Submission failed: {str(e)}", gr.update(visible=True)
|
| 1615 |
+
|
| 1616 |
+
submit_feedback_btn.click(
|
| 1617 |
+
fn=submit_user_feedback,
|
| 1618 |
+
inputs=[current_query_id, score_slider, feedback_box, annotator],
|
| 1619 |
+
outputs=[feedback_status, feedback_status]
|
| 1620 |
+
)
|
| 1621 |
+
|
| 1622 |
+
gr.Markdown(
|
| 1623 |
+
"""
|
| 1624 |
+
---
|
| 1625 |
+
### 💡 Technical Details
|
| 1626 |
+
|
| 1627 |
+
**MicroscopyMatching** - A general-purpose microscopy image analysis toolkit based on Stable Diffusion
|
| 1628 |
+
"""
|
| 1629 |
+
)
|
| 1630 |
+
|
| 1631 |
+
if __name__ == "__main__":
|
| 1632 |
+
demo.queue().launch(
|
| 1633 |
+
server_name="0.0.0.0",
|
| 1634 |
+
server_port=7860,
|
| 1635 |
+
share=False,
|
| 1636 |
+
ssr_mode=False,
|
| 1637 |
+
show_error=True,
|
| 1638 |
+
)
|
config.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass, field
|
| 2 |
+
from pathlib import Path
|
| 3 |
+
from typing import Dict, List
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
|
| 7 |
+
class RunConfig:
|
| 8 |
+
# Guiding text prompt
|
| 9 |
+
prompt: str = "<task-prompt>"
|
| 10 |
+
# Whether to use Stable Diffusion v2.1
|
| 11 |
+
sd_2_1: bool = False
|
| 12 |
+
# Which token indices to alter with attend-and-excite
|
| 13 |
+
token_indices: List[int] = field(default_factory=lambda: [2,5])
|
| 14 |
+
# Which random seeds to use when generating
|
| 15 |
+
seeds: List[int] = field(default_factory=lambda: [42])
|
| 16 |
+
# Path to save all outputs to
|
| 17 |
+
output_path: Path = Path('./outputs')
|
| 18 |
+
# Number of denoising steps
|
| 19 |
+
n_inference_steps: int = 50
|
| 20 |
+
# Text guidance scale
|
| 21 |
+
guidance_scale: float = 7.5
|
| 22 |
+
# Number of denoising steps to apply attend-and-excite
|
| 23 |
+
max_iter_to_alter: int = 25
|
| 24 |
+
# Resolution of UNet to compute attention maps over
|
| 25 |
+
attention_res: int = 16
|
| 26 |
+
# Whether to run standard SD or attend-and-excite
|
| 27 |
+
run_standard_sd: bool = False
|
| 28 |
+
# Dictionary defining the iterations and desired thresholds to apply iterative latent refinement in
|
| 29 |
+
thresholds: Dict[int, float] = field(default_factory=lambda: {0: 0.05, 10: 0.5, 20: 0.8})
|
| 30 |
+
# Scale factor for updating the denoised latent z_t
|
| 31 |
+
scale_factor: int = 20
|
| 32 |
+
# Start and end values used for scaling the scale factor - decays linearly with the denoising timestep
|
| 33 |
+
scale_range: tuple = field(default_factory=lambda: (1.0, 0.5))
|
| 34 |
+
# Whether to apply the Gaussian smoothing before computing the maximum attention value for each subject token
|
| 35 |
+
smooth_attentions: bool = True
|
| 36 |
+
# Standard deviation for the Gaussian smoothing
|
| 37 |
+
sigma: float = 0.5
|
| 38 |
+
# Kernel size for the Gaussian smoothing
|
| 39 |
+
kernel_size: int = 3
|
| 40 |
+
# Whether to save cross attention maps for the final results
|
| 41 |
+
save_cross_attention_maps: bool = False
|
| 42 |
+
|
| 43 |
+
def __post_init__(self):
|
| 44 |
+
self.output_path.mkdir(exist_ok=True, parents=True)
|
counting.py
ADDED
|
@@ -0,0 +1,340 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# stable diffusion x loca
|
| 2 |
+
import os
|
| 3 |
+
import pprint
|
| 4 |
+
from typing import Any, List, Optional
|
| 5 |
+
import argparse
|
| 6 |
+
from huggingface_hub import hf_hub_download
|
| 7 |
+
import pyrallis
|
| 8 |
+
from pytorch_lightning.utilities.types import STEP_OUTPUT
|
| 9 |
+
import torch
|
| 10 |
+
import os
|
| 11 |
+
from PIL import Image
|
| 12 |
+
import numpy as np
|
| 13 |
+
from config import RunConfig
|
| 14 |
+
from _utils import attn_utils_new as attn_utils
|
| 15 |
+
from _utils.attn_utils import AttentionStore
|
| 16 |
+
from _utils.misc_helper import *
|
| 17 |
+
import torch.nn.functional as F
|
| 18 |
+
import matplotlib.pyplot as plt
|
| 19 |
+
import cv2
|
| 20 |
+
import warnings
|
| 21 |
+
from pytorch_lightning.callbacks import ModelCheckpoint
|
| 22 |
+
warnings.filterwarnings("ignore", category=UserWarning)
|
| 23 |
+
import pytorch_lightning as pl
|
| 24 |
+
from _utils.load_models import load_stable_diffusion_model
|
| 25 |
+
from models.model import Counting_with_SD_features_loca as Counting
|
| 26 |
+
from pytorch_lightning.loggers import WandbLogger
|
| 27 |
+
from models.enc_model.loca_args import get_argparser as loca_get_argparser
|
| 28 |
+
from models.enc_model.loca import build_model as build_loca_model
|
| 29 |
+
import time
|
| 30 |
+
import torchvision.transforms as T
|
| 31 |
+
import skimage.io as io
|
| 32 |
+
|
| 33 |
+
SCALE = 1
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class CountingModule(pl.LightningModule):
    """Exemplar-guided object counting on top of Stable Diffusion features.

    Pipeline: a LOCA backbone extracts counting features, a learned
    "<task-prompt>" token (plus optionally an adapter embedding derived from
    exemplar boxes) conditions a Stable Diffusion UNet, and the UNet's
    cross/self-attention maps are aggregated into an attention stack that the
    LOCA regression head turns into a density map. The predicted count is the
    sum of the density map.
    """

    def __init__(self, use_box=True):
        # use_box=True: caller supplies exemplar bounding boxes in forward();
        # use_box=False: box-free path driven by dummy boxes + attention only.
        super().__init__()
        self.use_box = use_box
        self.config = RunConfig()  # config for stable diffusion
        self.initialize_model()

    def initialize_model(self):
        """Build LOCA, the counting adapter, the SD pipeline with attention
        hooks, and register/initialize the placeholder task token.

        Raises:
            ValueError: if the placeholder token is already in the tokenizer.
        """
        # load loca model (argparse defaults; no CLI args expected here)
        loca_args = loca_get_argparser().parse_args()
        self.loca_model = build_loca_model(loca_args)
        # weights = torch.load("ckpt/loca_few_shot.pt")["model"]
        # weights = {k.replace("module","") : v for k, v in weights.items()}
        # self.loca_model.load_state_dict(weights, strict=False)
        # del weights

        self.counting_adapter = Counting(scale_factor=SCALE)
        # if os.path.isfile(self.args.adapter_weight):
        #     adapter_weight = torch.load(self.args.adapter_weight,map_location=torch.device('cpu'))
        #     self.counting_adapter.load_state_dict(adapter_weight, strict=False)

        ### load stable diffusion and its controller
        # The controller records attention maps emitted by the UNet so they
        # can be aggregated after the forward pass.
        self.stable = load_stable_diffusion_model(config=self.config)
        self.noise_scheduler = self.stable.scheduler
        self.controller = AttentionStore(max_size=64)
        attn_utils.register_attention_control(self.stable, self.controller)
        attn_utils.register_hier_output(self.stable)

        ##### initialize token_emb #####
        placeholder_token = "<task-prompt>"
        self.task_token = "repetitive objects"
        # Add the placeholder token in tokenizer
        num_added_tokens = self.stable.tokenizer.add_tokens(placeholder_token)
        if num_added_tokens == 0:
            raise ValueError(
                f"The tokenizer already contains the token {placeholder_token}. Please pass a different"
                " `placeholder_token` that is not already in the tokenizer."
            )
        try:
            # Preferred path: load the pretrained task-token embedding from
            # the Hub and write it into the (resized) embedding table.
            task_embed_from_pretrain = hf_hub_download(
                repo_id="phoebe777777/111",
                filename="task_embed.pth",
                token=None,
                force_download=False
            )
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)
            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))

            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            # NOTE(review): this assigns the hf_hub_download return value (a
            # file path string) directly into the embedding row — presumably
            # a torch.load of the file is intended here; confirm upstream.
            token_embeds[placeholder_token_id] = task_embed_from_pretrain
        except:
            # Fallback: initialize the placeholder embedding from the
            # embedding of a single-token word ("count").
            initializer_token = "count"
            token_ids = self.stable.tokenizer.encode(initializer_token, add_special_tokens=False)
            # Check if initializer_token is a single token or a sequence of tokens
            if len(token_ids) > 1:
                raise ValueError("The initializer token must be a single token.")

            initializer_token_id = token_ids[0]
            placeholder_token_id = self.stable.tokenizer.convert_tokens_to_ids(placeholder_token)

            self.stable.text_encoder.resize_token_embeddings(len(self.stable.tokenizer))

            token_embeds = self.stable.text_encoder.get_input_embeddings().weight.data
            token_embeds[placeholder_token_id] = token_embeds[initializer_token_id]

        # others
        self.placeholder_token = placeholder_token
        self.placeholder_token_id = placeholder_token_id

    def move_to_device(self, device):
        """Move the SD pipeline, LOCA, the adapter, and this module to `device`."""
        self.stable.to(device)
        if self.loca_model is not None and self.counting_adapter is not None:
            self.loca_model.to(device)
            self.counting_adapter.to(device)
        self.to(device)

    def forward(self, data_path, box=None):
        """Predict a density map and object count for one image.

        Args:
            data_path: path to the input image (any PIL-readable format).
            box: optional exemplar boxes [[x1, y1, x2, y2], ...] in original
                image pixel coordinates; must be None iff self.use_box is False.

        Returns:
            (pred_density_rsz, pred_cnt): density map resized to the original
            image size (numpy, sums to pred_cnt) and the predicted count.
        """
        filename = data_path.split("/")[-1]  # (unused)
        img = Image.open(data_path).convert("RGB")
        width, height = img.size
        # One 512x512 copy for LOCA (ImageNet-normalized) and one for the SD
        # VAE (shifted to roughly [-0.5, 0.5]).
        input_image = T.Compose([T.ToTensor(), T.Resize((512, 512))])(img)
        input_image_stable = input_image - 0.5
        input_image = T.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])(input_image)
        if box is not None:
            # Rescale boxes from original-image pixels to the 512x512 grid.
            boxes = torch.tensor(box) / torch.tensor([width, height, width, height]) * 512  # xyxy, normalized
            assert self.use_box == True
        else:
            boxes = torch.tensor([[100,100,130,130], [200,200,250,250]], dtype=torch.float32)  # dummy box
            assert self.use_box == False

        # move to device
        input_image = input_image.unsqueeze(0).to(self.device)
        boxes = boxes.unsqueeze(0).to(self.device)
        input_image_stable = input_image_stable.unsqueeze(0).to(self.device)

        # Encode to SD latent space (0.18215 is the SD VAE scaling constant)
        # and add scheduler noise at a fixed, small timestep (t=20).
        latents = self.stable.vae.encode(input_image_stable).latent_dist.sample().detach()
        latents = latents * 0.18215
        # Sample noise that we'll add to the latents
        noise = torch.randn_like(latents)
        timesteps = torch.tensor([20], device=latents.device).long()
        noisy_latents = self.noise_scheduler.add_noise(latents, noise, timesteps)
        input_ids_ = self.stable.tokenizer(
            self.placeholder_token + " repetitive objects",
            # "object",
            padding="max_length",
            truncation=True,
            max_length=self.stable.tokenizer.model_max_length,
            return_tensors="pt",
        )
        input_ids = input_ids_["input_ids"].to(self.device)
        attention_mask = input_ids_["attention_mask"].to(self.device)
        encoder_hidden_states = self.stable.text_encoder(input_ids, attention_mask)[0]

        input_image = input_image.to(self.device)
        boxes = boxes.to(self.device)

        # Position of the placeholder token inside the tokenized prompt.
        task_loc_idx = torch.nonzero(input_ids == self.placeholder_token_id)
        if self.use_box:
            # Inject the exemplar-derived adapter embedding into the text
            # conditioning, right after the task-prompt token.
            loca_out = self.loca_model.forward_before_reg(input_image, boxes)
            loca_feature_bf_regression = loca_out["feature_bf_regression"]
            adapted_emb = self.counting_adapter.adapter(loca_feature_bf_regression, boxes)  # shape [1, 768]
            if task_loc_idx.shape[0] == 0:
                encoder_hidden_states[0,2,:] = adapted_emb.squeeze()  # place right after the task-prompt token
            else:
                encoder_hidden_states[0,task_loc_idx[0, 1]+1,:] = adapted_emb.squeeze()  # place right after the task-prompt token

        # Predict the noise residual
        # register_hier_output makes the UNet return intermediate features too.
        noise_pred, feature_list = self.stable.unet(noisy_latents, timesteps, encoder_hidden_states)
        noise_pred = noise_pred.sample  # (unused beyond this point)
        attention_store = self.controller.attention_store  # (unused)

        attention_maps = []
        exemplar_attention_maps = []
        exemplar_attention_maps1 = []
        exemplar_attention_maps2 = []
        exemplar_attention_maps3 = []

        cross_self_task_attn_maps = []
        cross_self_exe_attn_maps = []

        # only use 64x64 self-attention
        self_attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt],  # TODO(review): should this use the actual inference prompt?
            attention_store=self.controller,
            res=64,
            from_where=("up", "down"),
            is_cross=False,
            select=0
        )
        # NOTE(review): the 32/16 self-attention aggregates below are computed
        # but never used afterwards — confirm whether they can be dropped.
        self_attn_aggregate32 = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt],  # TODO(review): should this use the actual inference prompt?
            attention_store=self.controller,
            res=32,
            from_where=("up", "down"),
            is_cross=False,
            select=0
        )
        self_attn_aggregate16 = attn_utils.aggregate_attention(  # [res, res, 4096]
            prompts=[self.config.prompt],  # TODO(review): should this use the actual inference prompt?
            attention_store=self.controller,
            res=16,
            from_where=("up", "down"),
            is_cross=False,
            select=0
        )

        # cross attention
        # Collect cross-attention at two resolutions; token index 1 is the
        # task-prompt token, indices 2..4 are exemplar slots.
        for res in [32, 16]:
            attn_aggregate = attn_utils.aggregate_attention(  # [res, res, 77]
                prompts=[self.config.prompt],  # TODO(review): should this use the actual inference prompt?
                attention_store=self.controller,
                res=res,
                from_where=("up", "down"),
                is_cross=True,
                select=0
            )

            task_attn_ = attn_aggregate[:, :, 1].unsqueeze(0).unsqueeze(0)  # [1, 1, res, res]
            attention_maps.append(task_attn_)
            if self.use_box:
                exemplar_attns = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)  # take the exemplar token's attention
                exemplar_attention_maps.append(exemplar_attns)
            else:
                exemplar_attns1 = attn_aggregate[:, :, 2].unsqueeze(0).unsqueeze(0)
                exemplar_attns2 = attn_aggregate[:, :, 3].unsqueeze(0).unsqueeze(0)
                exemplar_attns3 = attn_aggregate[:, :, 4].unsqueeze(0).unsqueeze(0)
                exemplar_attention_maps1.append(exemplar_attns1)
                exemplar_attention_maps2.append(exemplar_attns2)
                exemplar_attention_maps3.append(exemplar_attns3)

        # Upsample each cross-attention map to 64x64 and average over
        # resolutions, then refine it with the 64x64 self-attention.
        scale_factors = [(64 // attention_maps[i].shape[-1]) for i in range(len(attention_maps))]
        attns = torch.cat([F.interpolate(attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(attention_maps))])
        task_attn_64 = torch.mean(attns, dim=0, keepdim=True)
        cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
        cross_self_task_attn_maps.append(cross_self_task_attn)

        if self.use_box:
            scale_factors = [(64 // exemplar_attention_maps[i].shape[-1]) for i in range(len(exemplar_attention_maps))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps))])
            exemplar_attn_64 = torch.mean(attns, dim=0, keepdim=True)

            cross_self_exe_attn = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64)
            cross_self_exe_attn_maps.append(cross_self_exe_attn)
        else:
            # Box-free path: average the three exemplar-slot attention maps.
            scale_factors = [(64 // exemplar_attention_maps1[i].shape[-1]) for i in range(len(exemplar_attention_maps1))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps1[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps1))])
            exemplar_attn_64_1 = torch.mean(attns, dim=0, keepdim=True)

            scale_factors = [(64 // exemplar_attention_maps2[i].shape[-1]) for i in range(len(exemplar_attention_maps2))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps2[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps2))])
            exemplar_attn_64_2 = torch.mean(attns, dim=0, keepdim=True)

            scale_factors = [(64 // exemplar_attention_maps3[i].shape[-1]) for i in range(len(exemplar_attention_maps3))]
            attns = torch.cat([F.interpolate(exemplar_attention_maps3[i_], scale_factor=scale_factors[i_], mode="bilinear") for i_ in range(len(exemplar_attention_maps3))])
            exemplar_attn_64_3 = torch.mean(attns, dim=0, keepdim=True)

            # NOTE(review): cross_self_task_attn is recomputed here identically
            # to the computation just above the branch — confirm intent.
            cross_self_task_attn = attn_utils.self_cross_attn(self_attn_aggregate, task_attn_64)
            cross_self_task_attn_maps.append(cross_self_task_attn)

            # if self.args.merge_exemplar == "average":
            cross_self_exe_attn1 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_1)
            cross_self_exe_attn2 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_2)
            cross_self_exe_attn3 = attn_utils.self_cross_attn(self_attn_aggregate, exemplar_attn_64_3)
            exemplar_attn_64 = (exemplar_attn_64_1 + exemplar_attn_64_2 + exemplar_attn_64_3) / 3
            cross_self_exe_attn = (cross_self_exe_attn1 + cross_self_exe_attn2 + cross_self_exe_attn3) / 3

        # Min-max normalize the exemplar attention (epsilon avoids 0-division).
        exemplar_attn_64 = (exemplar_attn_64 - exemplar_attn_64.min()) / (exemplar_attn_64.max() - exemplar_attn_64.min() + 1e-6)

        # 4-channel attention stack consumed by the LOCA regression head.
        attn_stack = [exemplar_attn_64 / 2, cross_self_exe_attn / 2, exemplar_attn_64, cross_self_exe_attn]
        attn_stack = torch.cat(attn_stack, dim=1)

        if not self.use_box:
            # cross_self_exe_attn_np = cross_self_exe_attn.detach().squeeze().cpu().numpy()
            # boxes = gen_dummy_boxes(cross_self_exe_attn_np, max_boxes=1)
            # boxes = boxes.to(self.device)

            # NOTE(review): indentation reconstructed — in the box-guided path
            # loca_out comes from the earlier forward_before_reg call; confirm.
            loca_out = self.loca_model.forward_before_reg(input_image, boxes)
            loca_feature_bf_regression = loca_out["feature_bf_regression"]
        attn_out = self.loca_model.forward_reg(loca_out, attn_stack, feature_list[-1])
        pred_density = attn_out["pred"].squeeze().cpu().numpy()
        pred_cnt = pred_density.sum().item()

        # resize pred_density to original image size
        # Renormalize so the resized map still sums to the predicted count.
        pred_density_rsz = cv2.resize(pred_density, (width, height), interpolation=cv2.INTER_CUBIC)
        pred_density_rsz = pred_density_rsz / pred_density_rsz.sum() * pred_cnt

        return pred_density_rsz, pred_cnt
|
| 293 |
+
|
| 294 |
+
|
| 295 |
+
def inference(data_path, box=None, save_path="./example_imgs", visualize=False):
    """Run the counting model on one image and optionally save a visualization.

    Args:
        data_path: path to the input image.
        box: optional exemplar boxes [[x1, y1, x2, y2], ...]; None enables the
            box-free counting path.
        save_path: directory where the visualization PNG is written.
        visualize: when True, saves a side-by-side input/density figure.

    Returns:
        The predicted density map (numpy array, original image size).
    """
    use_box = box is not None
    model = CountingModule(use_box=use_box)
    # map_location="cpu" so the checkpoint loads on GPU-less machines too.
    model.load_state_dict(
        torch.load("pretrained/microscopy_matching_cnt.pth", map_location="cpu"),
        strict=True,
    )
    model.eval()
    with torch.no_grad():
        density_map, cnt = model(data_path, box)

    if visualize:
        img = io.imread(data_path)
        # Normalize channel layout: drop alpha, replicate grayscale to RGB.
        if len(img.shape) == 3 and img.shape[2] > 3:
            img = img[:, :, :3]
        if len(img.shape) == 2:
            img = np.stack([img] * 3, axis=-1)
        img_show = img.squeeze()
        density_map_show = density_map.squeeze()
        os.makedirs(save_path, exist_ok=True)
        # os.path.basename is portable (handles both / and \ separators).
        filename = os.path.basename(data_path)
        # Epsilon guards against a constant image (max == min -> 0-division).
        img_show = (img_show - np.min(img_show)) / (np.max(img_show) - np.min(img_show) + 1e-8)
        fig, ax = plt.subplots(1, 2, figsize=(12, 6))
        ax[0].imshow(img_show)
        ax[0].axis('off')
        ax[0].set_title(f"Input image")
        ax[1].imshow(img_show)
        ax[1].imshow(density_map_show, cmap='jet', alpha=0.5)  # Overlay density map with some transparency
        ax[1].axis('off')
        ax[1].set_title(f"Predicted density map, count: {cnt:.1f}")
        plt.tight_layout()
        plt.savefig(os.path.join(save_path, filename.split(".")[0] + "_cnt.png"), dpi=300)
        plt.close()
    return density_map
|
| 329 |
+
|
| 330 |
+
def main():
    """Demo entry point: count objects in a bundled example image."""
    example_image = "example_imgs/1977_Well_F-5_Field_1.png"
    output_dir = "./example_imgs"
    inference(
        data_path=example_image,
        # box=[[150, 60, 183, 87]],
        save_path=output_dir,
        visualize=True,
    )


if __name__ == "__main__":
    main()
|
example_imgs/cnt/047cell.png
ADDED
|
Git LFS Details
|
example_imgs/cnt/62_10.png
ADDED
|
Git LFS Details
|
example_imgs/cnt/6800-17000_GTEX-XQ3S_Adipose-Subcutaneous.png
ADDED
|
Git LFS Details
|
example_imgs/seg/003_img.png
ADDED
|
Git LFS Details
|
example_imgs/seg/1-23 [Scan I08].png
ADDED
|
Git LFS Details
|
example_imgs/seg/10X_B2_Tile-15.aligned.png
ADDED
|
Git LFS Details
|
example_imgs/seg/1977_Well_F-5_Field_1.png
ADDED
|
Git LFS Details
|
example_imgs/seg/200972823[5179]_RhoGGG_YAP_TAZ [200972823 Well K6 Field #2].png
ADDED
|
Git LFS Details
|
example_imgs/seg/A172_Phase_C7_1_00d00h00m_1.png
ADDED
|
Git LFS Details
|
example_imgs/seg/JE2NileRed_oilp22_PMP_101220_011_NR.png
ADDED
|
Git LFS Details
|
example_imgs/seg/OpenTest_031.png
ADDED
|
Git LFS Details
|
example_imgs/seg/X_24.png
ADDED
|
Git LFS Details
|
example_imgs/seg/exp_A01_G002_0001.oir.png
ADDED
|
Git LFS Details
|
example_imgs/tra/tracking_test_sequence.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:bda69434e3de8103c98313777640acd35fc7501eec4b1528456304142b18797f
|
| 3 |
+
size 10392163
|
example_imgs/tra/tracking_test_sequence2.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:120cc2a75a4dd571b8f8ee7ea363a9b82a2b4c516376ccf4f287b6864d2dd576
|
| 3 |
+
size 2288296
|
inference_count.py
ADDED
|
@@ -0,0 +1,237 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inference_count.py
|
| 2 |
+
# 计数模型推理模块 - 独立版本
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
from PIL import Image
|
| 7 |
+
import matplotlib.pyplot as plt
|
| 8 |
+
import tempfile
|
| 9 |
+
import os
|
| 10 |
+
from huggingface_hub import hf_hub_download
|
| 11 |
+
from counting import CountingModule
|
| 12 |
+
|
| 13 |
+
MODEL = None
|
| 14 |
+
DEVICE = torch.device("cpu")
|
| 15 |
+
|
| 16 |
+
def load_model(use_box=False):
    """Load the counting model and its pretrained checkpoint.

    Args:
        use_box: whether exemplar bounding boxes will be supplied at inference.

    Returns:
        (model, device): the loaded model and the device it lives on, or
        (None, torch.device("cpu")) when loading failed.
    """
    global MODEL, DEVICE

    try:
        print("🔄 Loading counting model...")

        # Build the model graph first, then fill it with pretrained weights.
        MODEL = CountingModule(use_box=use_box)

        # Download the checkpoint from the Hugging Face Hub (cached locally
        # after the first run thanks to force_download=False).
        ckpt_path = hf_hub_download(
            repo_id="phoebe777777/111",
            filename="microscopy_matching_cnt.pth",
            token=None,
            force_download=False
        )

        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        # Load on CPU first so this also works on GPU-less machines.
        MODEL.load_state_dict(
            torch.load(ckpt_path, map_location="cpu"),
            strict=True
        )
        MODEL.eval()

        # Pick the best available device once, then move every sub-model
        # (previously the move_to_device call was duplicated in each branch).
        if torch.cuda.is_available():
            DEVICE = torch.device("cuda")
        else:
            DEVICE = torch.device("cpu")
        MODEL.move_to_device(DEVICE)
        if DEVICE.type == "cuda":
            print("✅ Model moved to CUDA")
        else:
            print("✅ Model on CPU")

        print("✅ Counting model loaded successfully")
        return MODEL, DEVICE

    except Exception as e:
        print(f"❌ Error loading counting model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
@torch.no_grad()
def run(model, img_path, box=None, device="cpu", visualize=True):
    """Run counting inference on a single image.

    Args:
        model: model returned by load_model() (may be None if loading failed).
        img_path: path to the input image.
        box: exemplar boxes [[x1, y1, x2, y2], ...] or None for box-free mode.
        device: device to run on.
        visualize: kept for API compatibility (visualization handled by caller).

    Returns:
        dict with keys 'density_map' (numpy array or None), 'count' (float),
        'visualized_path' (always None here), and 'error' only on failure.
    """
    # Bug fix: check for a missing model BEFORE touching it. The original
    # called model.move_to_device() first, so model=None raised an
    # AttributeError instead of returning the documented error dict.
    if model is None:
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': 'Model not loaded'
        }

    print("DEVICE:", device)
    model.move_to_device(device)
    model.eval()
    # The presence of exemplar boxes selects the box-guided branch.
    model.use_box = box is not None

    try:
        print(f"🔄 Running counting inference on {img_path}")

        # no_grad again is harmless under the decorator; kept for clarity.
        with torch.no_grad():
            density_map, count = model(img_path, box)

        print(f"✅ Counting result: {count:.1f} objects")

        result = {
            'density_map': density_map,
            'count': count,
            'visualized_path': None
        }

        # if visualize:
        #     viz_path = visualize_result(img_path, density_map, count)
        #     result['visualized_path'] = viz_path

        return result

    except Exception as e:
        print(f"❌ Counting inference error: {e}")
        import traceback
        traceback.print_exc()
        return {
            'density_map': None,
            'count': 0,
            'visualized_path': None,
            'error': str(e)
        }
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def visualize_result(image_path, density_map, count):
    """Overlay the predicted density map on the input image and save a PNG.

    Args:
        image_path: path of the original image.
        density_map: predicted density map (numpy array).
        count: predicted count (accepted for API symmetry; not rendered).

    Returns:
        Path of the saved temporary PNG, or image_path if rendering failed.
    """
    try:
        import skimage.io as io

        raw = io.imread(image_path)

        # Normalize channel layout: drop alpha, replicate grayscale to RGB.
        if len(raw.shape) == 3 and raw.shape[2] > 3:
            raw = raw[:, :, :3]
        if len(raw.shape) == 2:
            raw = np.stack([raw] * 3, axis=-1)

        base = raw.squeeze()
        overlay = density_map.squeeze()

        # Stretch intensities to [0, 1]; epsilon guards constant images.
        lo = np.min(base)
        hi = np.max(base)
        base = (base - lo) / (hi - lo + 1e-8)

        # Single panel: image with a semi-transparent density heat map on top.
        fig, ax = plt.subplots(figsize=(8, 6))
        ax.imshow(base)
        ax.imshow(overlay, cmap='jet', alpha=0.5)
        ax.axis('off')

        plt.tight_layout()

        # Persist to a temp file the caller can hand to the UI.
        out_file = tempfile.NamedTemporaryFile(delete=False, suffix='.png')
        plt.savefig(out_file.name, dpi=300)
        plt.close()

        print(f"✅ Visualization saved to {out_file.name}")
        return out_file.name

    except Exception as e:
        print(f"❌ Visualization error: {e}")
        import traceback
        traceback.print_exc()
        return image_path
|
| 196 |
+
|
| 197 |
+
|
| 198 |
+
# ===== Smoke-test script =====
# Loads the counting model and, if an example image is present, runs one
# inference and prints the results. Intended for manual local testing only.
if __name__ == "__main__":
    print("="*60)
    print("Testing Counting Model")
    print("="*60)

    # Test model loading (box-free mode)
    model, device = load_model(use_box=False)

    if model is not None:
        print("\n" + "="*60)
        print("Model loaded successfully, testing inference...")
        print("="*60)

        # Test inference on a bundled example image
        test_image = "example_imgs/1977_Well_F-5_Field_1.png"

        if os.path.exists(test_image):
            result = run(
                model,
                test_image,
                box=None,
                device=device,
                visualize=True
            )

            # run() only adds an 'error' key when inference failed.
            if 'error' not in result:
                print("\n" + "="*60)
                print("Inference Results:")
                print("="*60)
                print(f"Count: {result['count']:.1f}")
                print(f"Density map shape: {result['density_map'].shape}")
                if result['visualized_path']:
                    print(f"Visualization saved to: {result['visualized_path']}")
            else:
                print(f"\n❌ Inference failed: {result['error']}")
        else:
            print(f"\n⚠️ Test image not found: {test_image}")
    else:
        print("\n❌ Model loading failed")
|
inference_seg.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from huggingface_hub import hf_hub_download
|
| 4 |
+
from segmentation import SegmentationModule
|
| 5 |
+
|
| 6 |
+
MODEL = None
|
| 7 |
+
DEVICE = torch.device("cpu")
|
| 8 |
+
|
| 9 |
+
def load_model(use_box=False):
    """Load the segmentation model and its pretrained checkpoint.

    Args:
        use_box: whether exemplar bounding boxes will be supplied at inference.

    Returns:
        (model, device): the loaded SegmentationModule and the device it
        was moved to.
    """
    global MODEL, DEVICE
    MODEL = SegmentationModule(use_box=use_box)

    # Download the checkpoint from the Hub (cached after the first run).
    ckpt_path = hf_hub_download(
        repo_id="phoebe777777/111",
        filename="microscopy_matching_seg.pth",
        token=None,
        force_download=False
    )
    # strict=False: the checkpoint may omit some submodules (e.g. frozen SD).
    MODEL.load_state_dict(torch.load(ckpt_path, map_location="cpu"), strict=False)
    MODEL.eval()
    # Pick the best available device once, then move every sub-model
    # (previously the move_to_device call was duplicated in each branch).
    if torch.cuda.is_available():
        DEVICE = torch.device("cuda")
    else:
        DEVICE = torch.device("cpu")
    MODEL.move_to_device(DEVICE)
    if DEVICE.type == "cuda":
        print("✅ Model moved to CUDA")
    else:
        print("✅ Model on CPU")
    return MODEL, DEVICE
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@torch.no_grad()
def run(model, img_path, box=None, device="cpu"):
    """Run segmentation inference on a single image.

    Args:
        model: segmentation model returned by load_model().
        img_path: path to the input image.
        box: optional exemplar boxes; None selects the box-free path.
        device: device to run on.

    Returns:
        The predicted segmentation mask produced by the model.
    """
    print("DEVICE:", device)
    model.move_to_device(device)
    model.eval()
    # no_grad again is harmless under the decorator; kept for clarity.
    with torch.no_grad():
        # Presence of exemplar boxes toggles the box-guided branch.
        model.use_box = box is not None
        mask = model(img_path, box=box)
    return mask
|
| 46 |
+
# import os
|
| 47 |
+
# import torch
|
| 48 |
+
# import numpy as np
|
| 49 |
+
# from huggingface_hub import hf_hub_download
|
| 50 |
+
# from segmentation import SegmentationModule
|
| 51 |
+
|
| 52 |
+
# MODEL = None
|
| 53 |
+
# DEVICE = torch.device("cpu")
|
| 54 |
+
|
| 55 |
+
# def load_model(use_box=False):
|
| 56 |
+
# global MODEL, DEVICE
|
| 57 |
+
|
| 58 |
+
# # === 优化1: 使用 /data 缓存模型,避免写入 .cache ===
|
| 59 |
+
# cache_dir = "/data/cellseg_model_cache"
|
| 60 |
+
# os.makedirs(cache_dir, exist_ok=True)
|
| 61 |
+
|
| 62 |
+
# ckpt_path = hf_hub_download(
|
| 63 |
+
# repo_id="Shengxiao0709/cellsegmodel",
|
| 64 |
+
# filename="microscopy_matching_seg.pth",
|
| 65 |
+
# token=None,
|
| 66 |
+
# local_dir=cache_dir, # ✅ 下载到 /data
|
| 67 |
+
# local_dir_use_symlinks=False, # ✅ 避免软链接问题
|
| 68 |
+
# force_download=False # ✅ 已存在时不重复下载
|
| 69 |
+
# )
|
| 70 |
+
|
| 71 |
+
# # === 优化2: 加载模型 ===
|
| 72 |
+
# MODEL = SegmentationModule(use_box=use_box)
|
| 73 |
+
# state_dict = torch.load(ckpt_path, map_location="cpu")
|
| 74 |
+
# MODEL.load_state_dict(state_dict, strict=False)
|
| 75 |
+
# MODEL.eval()
|
| 76 |
+
|
| 77 |
+
# DEVICE = torch.device("cpu")
|
| 78 |
+
# print(f"✅ Model loaded from {ckpt_path}")
|
| 79 |
+
# return MODEL, DEVICE
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
# @torch.no_grad()
|
| 83 |
+
# def run(model, img_path, box=None, device="cpu"):
|
| 84 |
+
# output = model(img_path, box=box)
|
| 85 |
+
# mask = output["pred"]
|
| 86 |
+
# mask = (mask > 0).astype(np.uint8)
|
| 87 |
+
# return mask
|
inference_track.py
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# inference_track.py
|
| 2 |
+
# 视频跟踪模型推理模块
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import numpy as np
|
| 6 |
+
import os
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from tqdm import tqdm
|
| 9 |
+
from huggingface_hub import hf_hub_download
|
| 10 |
+
from tracking_one import TrackingModule
|
| 11 |
+
from models.tra_post_model.trackastra.tracking import graph_to_ctc
|
| 12 |
+
|
| 13 |
+
MODEL = None
|
| 14 |
+
DEVICE = torch.device("cpu")
|
| 15 |
+
|
| 16 |
+
def load_model(use_box=False):
    """Load the tracking model and its pretrained weights.

    Downloads the checkpoint from the Hugging Face Hub (cached after the
    first call), loads it strictly, and moves the model to CUDA when
    available, otherwise CPU.

    Args:
        use_box: whether the model should use bounding-box prompts.

    Returns:
        ``(model, device)`` on success, or ``(None, torch.device("cpu"))``
        when loading fails for any reason.
    """
    global MODEL, DEVICE

    try:
        print("🔄 Loading tracking model...")

        # Instantiate the tracking module first, then attach weights.
        MODEL = TrackingModule(use_box=use_box)

        ckpt_path = hf_hub_download(
            repo_id="phoebe777777/111",
            filename="microscopy_matching_tra.pth",
            token=None,
            force_download=False
        )
        print(f"✅ Checkpoint downloaded: {ckpt_path}")

        state = torch.load(ckpt_path, map_location="cpu")
        MODEL.load_state_dict(state, strict=True)
        MODEL.eval()

        # Pick the best available device, then move the model once.
        if torch.cuda.is_available():
            DEVICE = torch.device("cuda")
            placement_msg = "✅ Model moved to CUDA"
        else:
            DEVICE = torch.device("cpu")
            placement_msg = "✅ Model on CPU"
        MODEL.move_to_device(DEVICE)
        print(placement_msg)

        print("✅ Tracking model loaded successfully")
        return MODEL, DEVICE

    except Exception as e:
        print(f"❌ Error loading tracking model: {e}")
        import traceback
        traceback.print_exc()
        return None, torch.device("cpu")
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@torch.no_grad()
|
| 73 |
+
@torch.no_grad()
def run(model, video_dir, box=None, device="cpu", output_dir="tracked_results"):
    """Run tracking inference over a directory of consecutive video frames.

    Args:
        model: tracking model (as returned by ``load_model``); may be None.
        video_dir: directory holding the frame image sequence.
        box: optional bounding-box prompt.
        device: accepted for API symmetry; the model already lives on its
            device after ``load_model``.
        output_dir: directory where CTC-format results are written.

    Returns:
        On success, a dict with keys ``track_graph``, ``masks``,
        ``masks_tracked`` and ``output_dir``. On failure, a dict of None
        values plus ``num_tracks`` = 0 and an ``error`` message.
    """
    if model is None:
        return {
            'track_graph': None,
            'masks': None,
            'output_dir': None,
            'num_tracks': 0,
            'error': 'Model not loaded'
        }

    try:
        print(f"🔄 Running tracking inference on {video_dir}")

        # Run the tracker; "greedy" linking (alternatives: "greedy_nodiv", "ilp").
        track_graph, masks = model.track(
            file_dir=video_dir,
            boxes=box,
            mode="greedy",
            dataname="tracking_result"
        )

        # exist_ok avoids a race between an existence check and creation.
        os.makedirs(output_dir, exist_ok=True)

        # Convert the track graph to Cell Tracking Challenge (CTC) format.
        print("🔄 Converting to CTC format...")
        ctc_tracks, masks_tracked = graph_to_ctc(
            track_graph,
            masks,
            outdir=output_dir,
        )
        print(f"✅ CTC results saved to {output_dir}")

        # num_tracks = len(track_graph.tracks())

        print(f"✅ Tracking completed")

        # NOTE(review): the success dict intentionally omits 'num_tracks'
        # (its computation is disabled above); the error dicts include it.
        result = {
            'track_graph': track_graph,
            'masks': masks,
            'masks_tracked': masks_tracked,
            'output_dir': output_dir,
            # 'num_tracks': num_tracks
        }

        return result

    except Exception as e:
        print(f"❌ Tracking inference error: {e}")
        import traceback
        traceback.print_exc()
        return {
            'track_graph': None,
            'masks': None,
            'output_dir': None,
            'num_tracks': 0,
            'error': str(e)
        }
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def visualize_tracking_result(masks_tracked, output_path):
    """
    Visualize the tracking result as a color-coded mp4 video (optional).

    Args:
        masks_tracked: tracked label masks, shape (T, H, W); 0 = background.
        output_path: path of the output video file.

    Returns:
        ``output_path`` on success, None on any failure (including a
        missing cv2/matplotlib installation, caught by the broad except).
    """
    try:
        import cv2
        import matplotlib.pyplot as plt
        from matplotlib import cm

        # Number of frames and spatial size.
        T, H, W = masks_tracked.shape

        # One color per label id found anywhere in the sequence
        # (id 0 is background and stays black).
        unique_ids = np.unique(masks_tracked)
        num_colors = len(unique_ids)
        # NOTE(review): cm.get_cmap is deprecated in matplotlib >= 3.7
        # (removed in 3.9) — confirm the pinned matplotlib version.
        cmap = cm.get_cmap('tab20', num_colors)

        # mp4 writer at a fixed 5 fps.
        fourcc = cv2.VideoWriter_fourcc(*'mp4v')
        out = cv2.VideoWriter(output_path, fourcc, 5.0, (W, H))

        for t in range(T):
            frame = masks_tracked[t]

            # Paint each tracked object with its assigned colormap color.
            colored_frame = np.zeros((H, W, 3), dtype=np.uint8)
            for i, obj_id in enumerate(unique_ids):
                if obj_id == 0:
                    continue
                mask = (frame == obj_id)
                color = np.array(cmap(i % num_colors)[:3]) * 255
                colored_frame[mask] = color

            # OpenCV expects BGR channel order.
            colored_frame_bgr = cv2.cvtColor(colored_frame, cv2.COLOR_RGB2BGR)
            out.write(colored_frame_bgr)

        out.release()
        print(f"✅ Visualization saved to {output_path}")
        return output_path

    except Exception as e:
        print(f"❌ Visualization error: {e}")
        return None
|
models/.DS_Store
ADDED
|
Binary file (6.15 kB). View file
|
|
|
models/enc_model/__init__.py
ADDED
|
File without changes
|
models/enc_model/backbone.py
ADDED
|
@@ -0,0 +1,64 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
from torch.nn import functional as F
|
| 4 |
+
from torchvision import models
|
| 5 |
+
from torchvision.ops.misc import FrozenBatchNorm2d
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class Backbone(nn.Module):
    """ResNet feature extractor.

    Runs the torchvision ResNet stem and stages 1-4, then returns layers
    2, 3 and 4 bilinearly resized to ``input_size / reduction`` and
    concatenated along the channel axis.
    """

    def __init__(
        self,
        name: str,
        pretrained: bool,
        dilation: bool,
        reduction: int,
        swav: bool,
        requires_grad: bool
    ):

        super(Backbone, self).__init__()

        net = getattr(models, name)(
            replace_stride_with_dilation=[False, False, dilation],
            pretrained=pretrained, norm_layer=FrozenBatchNorm2d
        )

        self.backbone = net
        self.reduction = reduction

        # Optionally start from SwAV self-supervised weights (resnet50 only).
        if name == 'resnet50' and swav:
            checkpoint = torch.hub.load_state_dict_from_url(
                'https://dl.fbaipublicfiles.com/deepcluster/swav_800ep_pretrain.pth.tar',
                map_location="cpu"
            )
            cleaned = {k.replace("module.", ""): v for k, v in checkpoint.items()}
            self.backbone.load_state_dict(cleaned, strict=False)

        # Channel count of the concatenation of layers 2, 3 and 4.
        self.num_channels = 896 if name in ['resnet18', 'resnet34'] else 3584

        # Only stages 2-4 are ever trainable; the stem and layer1 stay frozen.
        for pname, param in self.backbone.named_parameters():
            trainable = any(tag in pname for tag in ('layer2', 'layer3', 'layer4'))
            param.requires_grad_(requires_grad if trainable else False)

    def forward(self, x):
        target = (x.size(-2) // self.reduction, x.size(-1) // self.reduction)

        net = self.backbone
        x = net.maxpool(net.relu(net.bn1(net.conv1(x))))
        x = net.layer1(x)

        # Collect the outputs of stages 2-4 as we go deeper.
        stage_outputs = []
        for stage in (net.layer2, net.layer3, net.layer4):
            x = stage(x)
            stage_outputs.append(x)

        return torch.cat([
            F.interpolate(f, size=target, mode='bilinear', align_corners=True)
            for f in stage_outputs
        ], dim=1)
|
models/enc_model/loca.py
ADDED
|
@@ -0,0 +1,232 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .backbone import Backbone
|
| 2 |
+
from .transformer import TransformerEncoder
|
| 3 |
+
from .ope import OPEModule
|
| 4 |
+
from .positional_encoding import PositionalEncodingsFixed
|
| 5 |
+
from .regression_head import DensityMapRegressor
|
| 6 |
+
|
| 7 |
+
import torch
|
| 8 |
+
from torch import nn
|
| 9 |
+
from torch.nn import functional as F
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LOCA(nn.Module):
    """LOCA few-shot counting network.

    Backbone features are projected and refined by a transformer encoder,
    matched against exemplar prototypes from the OPE module via grouped
    convolution, and the resulting response maps are decoded into density
    maps by the regression heads. The forward pass is split into
    ``forward_before_reg`` (feature/response computation) and one of the
    ``forward_reg*`` variants (fusion + regression).
    """

    def __init__(
        self,
        image_size: int,
        num_encoder_layers: int,
        num_ope_iterative_steps: int,
        num_objects: int,
        emb_dim: int,
        num_heads: int,
        kernel_dim: int,
        backbone_name: str,
        swav_backbone: bool,
        train_backbone: bool,
        reduction: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool,
    ):

        super(LOCA, self).__init__()

        self.emb_dim = emb_dim
        self.num_objects = num_objects
        self.reduction = reduction
        self.kernel_dim = kernel_dim
        self.image_size = image_size
        self.zero_shot = zero_shot
        self.num_heads = num_heads
        self.num_encoder_layers = num_encoder_layers

        self.backbone = Backbone(
            backbone_name, pretrained=True, dilation=False, reduction=reduction,
            swav=swav_backbone, requires_grad=train_backbone
        )
        # 1x1 projection from concatenated backbone channels to emb_dim.
        self.input_proj = nn.Conv2d(
            self.backbone.num_channels, emb_dim, kernel_size=1
        )

        if num_encoder_layers > 0:
            self.encoder = TransformerEncoder(
                num_encoder_layers, emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation, norm
            )

        self.ope = OPEModule(
            num_ope_iterative_steps, emb_dim, kernel_dim, num_objects, num_heads,
            reduction, layer_norm_eps, mlp_factor, norm_first, activation, norm, zero_shot
        )

        # Final head decodes the last OPE iteration; aux heads supervise
        # the earlier iterations during training.
        self.regression_head = DensityMapRegressor(emb_dim, reduction)
        self.aux_heads = nn.ModuleList([
            DensityMapRegressor(emb_dim, reduction)
            for _ in range(num_ope_iterative_steps - 1)
        ])

        self.pos_emb = PositionalEncodingsFixed(emb_dim)

        # Fusion layers used by forward_reg; spatial size hard-coded to 64x64.
        self.attn_norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.fuse = nn.Sequential(
            nn.Conv2d(324, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64))
        )

        # NOTE(review): forward_reg dispatches to self.fuse1 when the fused
        # feature has 322 channels, but fuse1 is commented out here — that
        # branch would raise AttributeError if ever taken. Confirm intent.
        # self.fuse1 = nn.Sequential(
        #     nn.Conv2d(322, 256, kernel_size=1, stride=1),
        #     nn.LeakyReLU(),
        #     nn.LayerNorm((64, 64))
        # )

    def forward_before_reg(self, x, bboxes):
        """Compute per-iteration response maps (no regression yet).

        Args:
            x: input image batch.
            bboxes: exemplar boxes, shape (bs, num_boxes, 4); ignored for
                box count in zero-shot mode.

        Returns:
            dict with 'feature_bf_regression' (last iteration response map)
            and 'aux_feature_bf_regression' (earlier iterations).
        """
        num_objects = bboxes.size(1) if not self.zero_shot else self.num_objects
        # backbone
        backbone_features = self.backbone(x)
        # prepare the encoder input
        src = self.input_proj(backbone_features)
        bs, c, h, w = src.size()
        pos_emb = self.pos_emb(bs, h, w, src.device).flatten(2).permute(2, 0, 1)
        src = src.flatten(2).permute(2, 0, 1)

        # push through the encoder
        if self.num_encoder_layers > 0:
            image_features = self.encoder(src, pos_emb, src_key_padding_mask=None, src_mask=None)
        else:
            image_features = src

        # prepare OPE input: back to (bs, emb_dim, h, w) feature map
        f_e = image_features.permute(1, 2, 0).reshape(-1, self.emb_dim, h, w)

        all_prototypes = self.ope(f_e, pos_emb, bboxes)  # [3, 27, 1, 256]

        outputs = list()
        response_maps_list = []
        for i in range(all_prototypes.size(0)):
            # Reshape the i-th iteration's prototypes into depthwise
            # conv kernels: one (kernel_dim x kernel_dim) kernel per
            # (batch, object, channel) group.
            prototypes = all_prototypes[i, ...].permute(1, 0, 2).reshape(
                bs, num_objects, self.kernel_dim, self.kernel_dim, -1
            ).permute(0, 1, 4, 2, 3).flatten(0, 2)[:, None, ...]  # [768, 1, 3, 3]

            # Correlate image features with every prototype (grouped conv),
            # then keep the strongest response over the exemplar objects.
            response_maps = F.conv2d(
                torch.cat([f_e for _ in range(num_objects)], dim=1).flatten(0, 1).unsqueeze(0),
                prototypes,
                bias=None,
                padding=self.kernel_dim // 2,
                groups=prototypes.size(0)
            ).view(
                bs, num_objects, self.emb_dim, h, w
            ).max(dim=1)[0]

            # # send through regression heads
            # if i == all_prototypes.size(0) - 1:
            #     predicted_dmaps = self.regression_head(response_maps)
            # else:
            #     predicted_dmaps = self.aux_heads[i](response_maps)
            # outputs.append(predicted_dmaps)
            response_maps_list.append(response_maps)

        out = {
            # "pred": outputs[-1],
            "feature_bf_regression": response_maps_list[-1],
            # "aux_pred": outputs[:-1],
            "aux_feature_bf_regression": response_maps_list[:-1]
        }

        return out

    def forward_reg(self, response_maps, attn_stack, unet_feature):
        """Fuse response maps with attention maps and a U-Net feature,
        then regress density maps for every OPE iteration.

        Returns:
            dict with 'pred' (final density map) and 'aux_pred' (earlier
            iterations, for auxiliary supervision).
        """
        attn_stack = self.attn_norm(attn_stack)
        attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        unet_feature = torch.cat([unet_feature, attn_stack], dim=1)  # [1, 324, 64, 64]
        unet_feature = unet_feature * attn_stack_mean
        # NOTE(review): the 322-channel branch needs self.fuse1, which is
        # commented out in __init__ — it would raise AttributeError.
        if unet_feature.shape[1] == 322:
            unet_feature = self.fuse1(unet_feature)
        else:
            unet_feature = self.fuse(unet_feature)

        # Process every iteration's response map; the last one gets the
        # main regression head, earlier ones the aux heads.
        response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]

        outputs = []
        for i in range(len(response_maps)):
            response_map = response_maps[i] + unet_feature
            if i == len(response_maps) - 1:
                predicted_dmaps = self.regression_head(response_map)
            else:
                predicted_dmaps = self.aux_heads[i](response_map)
            outputs.append(predicted_dmaps)

        return {"pred": outputs[-1], "aux_pred": outputs[:-1]}

    def forward_reg1(self, response_maps, self_attn):
        """Variant of forward_reg that adds a self-attention map directly
        to each response map (no fusion convolution, no U-Net feature)."""
        # attn_stack = self.attn_norm(attn_stack)
        # attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)
        # unet_feature = torch.cat([unet_feature, attn_stack], dim=1)  # [1, 324, 64, 64]
        # unet_feature = unet_feature * attn_stack_mean
        # if unet_feature.shape[1] == 322:
        #     unet_feature = self.fuse1(unet_feature)
        # else:
        #     unet_feature = self.fuse(unet_feature)

        response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]

        outputs = []
        for i in range(len(response_maps)):
            response_map = response_maps[i] + self_attn
            if i == len(response_maps) - 1:
                predicted_dmaps = self.regression_head(response_map)
            else:
                predicted_dmaps = self.aux_heads[i](response_map)
            outputs.append(predicted_dmaps)

        return {"pred": outputs[-1], "aux_pred": outputs[:-1]}

    def forward_reg_without_unet(self, response_maps, attn_stack):
        """Variant of forward_reg that reweights each response map by the
        mean attention map (scaled residual) instead of fusing features."""
        # attn_stack = self.attn_norm(attn_stack)
        attn_stack_mean = torch.mean(attn_stack, dim=1, keepdim=True)

        response_maps = response_maps["aux_feature_bf_regression"] + [response_maps["feature_bf_regression"]]

        outputs = []
        for i in range(len(response_maps)):
            # Residual attention reweighting: x * mean(attn) * 0.5 + x.
            response_map = response_maps[i] * attn_stack_mean * 0.5 + response_maps[i]
            if i == len(response_maps) - 1:
                predicted_dmaps = self.regression_head(response_map)
            else:
                predicted_dmaps = self.aux_heads[i](response_map)
            outputs.append(predicted_dmaps)

        return {"pred": outputs[-1], "aux_pred": outputs[:-1]}
|
| 206 |
+
|
| 207 |
+
|
| 208 |
+
def build_model(args):
    """Construct a LOCA model from parsed command-line arguments.

    Args:
        args: namespace produced by the LOCA argument parser.

    Returns:
        A configured ``LOCA`` instance.

    Raises:
        ValueError: if ``args.backbone`` or ``args.reduction`` is not
            one of the supported values.
    """
    # Explicit checks instead of asserts so validation survives `python -O`.
    if args.backbone not in ('resnet18', 'resnet50', 'resnet101'):
        raise ValueError(f"unsupported backbone: {args.backbone!r}")
    if args.reduction not in (4, 8, 16):
        raise ValueError(f"unsupported reduction: {args.reduction!r}")

    return LOCA(
        image_size=args.image_size,
        num_encoder_layers=args.num_enc_layers,
        num_ope_iterative_steps=args.num_ope_iterative_steps,
        num_objects=args.num_objects,
        zero_shot=args.zero_shot,
        emb_dim=args.emb_dim,
        num_heads=args.num_heads,
        kernel_dim=args.kernel_dim,
        backbone_name=args.backbone,
        swav_backbone=args.swav_backbone,
        train_backbone=args.backbone_lr > 0,  # backbone is trainable only with a positive LR
        reduction=args.reduction,
        dropout=args.dropout,
        layer_norm_eps=1e-5,
        mlp_factor=8,
        norm_first=args.pre_norm,
        activation=nn.GELU,
        norm=True,
    )
|
models/enc_model/loca_args.py
ADDED
|
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
def get_argparser():
    """Build the LOCA command-line parser (paths, architecture and
    training hyper-parameters). Help is disabled (``add_help=False``) so
    the parser can be composed as a parent parser."""
    parser = argparse.ArgumentParser("LOCA parser", add_help=False)

    # String-valued options: (flag, default).
    for flag, default in (
        ('--model_name', 'loca_few_shot'),
        ('--data_path', './data/FSC147_384_V2'),
        ('--model_path', 'ckpt'),
        ('--backbone', 'resnet50'),
    ):
        parser.add_argument(flag, default=default, type=str)

    # Integer-valued options: (flag, default).
    for flag, default in (
        ('--reduction', 8),
        ('--image_size', 512),
        ('--num_enc_layers', 3),
        ('--num_ope_iterative_steps', 3),
        ('--emb_dim', 256),
        ('--num_heads', 8),
        ('--kernel_dim', 3),
        ('--num_objects', 3),
        ('--epochs', 200),
        ('--lr_drop', 200),
        ('--batch_size', 1),
        ('--num_workers', 8),
    ):
        parser.add_argument(flag, default=default, type=int)

    # Float-valued options: (flag, default).
    for flag, default in (
        ('--lr', 1e-4),
        ('--backbone_lr', 0),
        ('--weight_decay', 1e-4),
        ('--dropout', 0.1),
        ('--max_grad_norm', 0.1),
        ('--aux_weight', 0.3),
        ('--tiling_p', 0.5),
    ):
        parser.add_argument(flag, default=default, type=float)

    # Boolean flags (swav_backbone and pre_norm default to enabled).
    parser.add_argument('--swav_backbone', action='store_true', default=True)
    parser.add_argument('--resume_training', action='store_true')
    parser.add_argument('--zero_shot', action='store_true')
    parser.add_argument('--pre_norm', action='store_true', default=True)

    return parser
|
models/enc_model/mlp.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
|
| 3 |
+
|
| 4 |
+
class MLP(nn.Module):
    """Two-layer feed-forward block.

    Applies Linear -> activation -> dropout -> Linear, projecting up to
    ``hidden_dim`` and back to ``input_dim``.
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dim: int,
        dropout: float,
        activation: nn.Module
    ):
        super(MLP, self).__init__()

        self.linear1 = nn.Linear(input_dim, hidden_dim)
        self.linear2 = nn.Linear(hidden_dim, input_dim)
        self.dropout = nn.Dropout(dropout)
        self.activation = activation()

    def forward(self, x):
        hidden = self.activation(self.linear1(x))
        return self.linear2(self.dropout(hidden))
|
models/enc_model/ope.py
ADDED
|
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mlp import MLP
|
| 2 |
+
from .positional_encoding import PositionalEncodingsFixed
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
|
| 7 |
+
from torchvision.ops import roi_align
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class OPEModule(nn.Module):
    """Object Prototype Extraction (OPE) module.

    Builds per-exemplar prototypes by combining a shape/objectness
    embedding with appearance features pooled from the exemplar boxes
    (roi_align), then refines them with the iterative adaptation
    transformer. In zero-shot mode the prototypes are learned parameters
    and no appearance pooling is done.
    """

    def __init__(
        self,
        num_iterative_steps: int,
        emb_dim: int,
        kernel_dim: int,
        num_objects: int,
        num_heads: int,
        reduction: int,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool,
    ):

        super(OPEModule, self).__init__()

        self.num_iterative_steps = num_iterative_steps
        self.zero_shot = zero_shot
        self.kernel_dim = kernel_dim
        self.num_objects = num_objects
        self.emb_dim = emb_dim
        self.reduction = reduction

        if num_iterative_steps > 0:
            self.iterative_adaptation = IterativeAdaptationModule(
                num_layers=num_iterative_steps, emb_dim=emb_dim, num_heads=num_heads,
                dropout=0, layer_norm_eps=layer_norm_eps,
                mlp_factor=mlp_factor, norm_first=norm_first,
                activation=activation, norm=norm,
                zero_shot=zero_shot
            )

        if not self.zero_shot:
            # Few-shot: map each exemplar's (h, w) to kernel_dim^2 shape tokens.
            self.shape_or_objectness = nn.Sequential(
                nn.Linear(2, 64),
                nn.ReLU(),
                nn.Linear(64, emb_dim),
                nn.ReLU(),
                nn.Linear(emb_dim, self.kernel_dim**2 * emb_dim)
            )
        else:
            # Zero-shot: learned objectness queries, one set per object slot.
            self.shape_or_objectness = nn.Parameter(
                torch.empty((self.num_objects, self.kernel_dim**2, emb_dim))
            )
            nn.init.normal_(self.shape_or_objectness)

        self.pos_emb = PositionalEncodingsFixed(emb_dim)

    def forward(self, f_e, pos_emb, bboxes):
        """Produce prototype tokens from image features and exemplar boxes.

        Args:
            f_e: encoded image features, shape (bs, emb_dim, h, w).
            pos_emb: flattened positional encodings for the image features.
            bboxes: exemplar boxes (bs, num_boxes, 4) in xyxy image coords.

        Returns:
            Prototypes stacked over adaptation steps (or a single step
            when no iterative adaptation is configured).
        """
        bs, _, h, w = f_e.size()
        # extract the shape features or objectness
        if not self.zero_shot:
            # Box width/height pairs are embedded into kernel_dim^2 tokens.
            box_hw = torch.zeros(bboxes.size(0), bboxes.size(1), 2).to(bboxes.device)
            box_hw[:, :, 0] = bboxes[:, :, 2] - bboxes[:, :, 0]
            box_hw[:, :, 1] = bboxes[:, :, 3] - bboxes[:, :, 1]
            shape_or_objectness = self.shape_or_objectness(box_hw).reshape(
                bs, -1, self.kernel_dim ** 2, self.emb_dim
            ).flatten(1, 2).transpose(0, 1)
        else:
            shape_or_objectness = self.shape_or_objectness.expand(
                bs, -1, -1, -1
            ).flatten(1, 2).transpose(0, 1)

        # if not zero shot add appearance
        if not self.zero_shot:
            # reshape bboxes into the format suitable for roi_align:
            # prepend the batch index to each box row.
            num_of_boxes = bboxes.size(1)
            bboxes = torch.cat([
                torch.arange(
                    bs, requires_grad=False
                ).to(bboxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
                bboxes.flatten(0, 1),
            ], dim=1)
            # Pool a kernel_dim x kernel_dim appearance patch per box;
            # spatial_scale maps image coords to feature-map coords.
            appearance = roi_align(
                f_e,
                boxes=bboxes, output_size=self.kernel_dim,
                spatial_scale=1.0 / self.reduction, aligned=True
            ).permute(0, 2, 3, 1).reshape(
                bs, num_of_boxes * self.kernel_dim ** 2, -1
            ).transpose(0, 1)
        else:
            num_of_boxes = self.num_objects
            appearance = None

        query_pos_emb = self.pos_emb(
            bs, self.kernel_dim, self.kernel_dim, f_e.device
        ).flatten(2).permute(2, 0, 1).repeat(num_of_boxes, 1, 1)

        if self.num_iterative_steps > 0:
            memory = f_e.flatten(2).permute(2, 0, 1)
            all_prototypes = self.iterative_adaptation(
                shape_or_objectness, appearance, memory, pos_emb, query_pos_emb
            )
        else:
            # No adaptation: sum (or pick) the available prototype sources
            # and add a singleton "steps" axis for a uniform return shape.
            if shape_or_objectness is not None and appearance is not None:
                all_prototypes = (shape_or_objectness + appearance).unsqueeze(0)
            else:
                all_prototypes = (
                    shape_or_objectness if shape_or_objectness is not None else appearance
                ).unsqueeze(0)

        return all_prototypes
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
|
| 119 |
+
class IterativeAdaptationModule(nn.Module):
    """Stack of iterative adaptation layers.

    Runs the prototype queries through every layer in turn and returns
    the (optionally layer-normalized) output of each step, stacked along
    a new leading axis.
    """

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
        zero_shot: bool
    ):

        super(IterativeAdaptationModule, self).__init__()

        self.layers = nn.ModuleList([
            IterativeAdaptationLayer(
                emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation, zero_shot
            ) for i in range(num_layers)
        ])

        # Shared output norm; identity when normalization is disabled.
        self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()

    def forward(
        self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask=None, memory_mask=None,
        tgt_key_padding_mask=None, memory_key_padding_mask=None
    ):
        current = tgt
        per_step = []
        for layer in self.layers:
            current = layer(
                current, appearance, memory, pos_emb, query_pos_emb,
                tgt_mask, memory_mask, tgt_key_padding_mask, memory_key_padding_mask
            )
            per_step.append(self.norm(current))

        return torch.stack(per_step)
|
| 161 |
+
|
| 162 |
+
|
| 163 |
+
class IterativeAdaptationLayer(nn.Module):
    """Single decoder-style adaptation layer.

    In the few-shot setting (``zero_shot=False``) it first attends from the
    object queries to the exemplar ``appearance`` embeddings, then always
    cross-attends over the encoder ``memory`` and applies an MLP, in either
    pre-norm (``norm_first=True``) or post-norm configuration.
    """

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        zero_shot: bool
    ):
        super(IterativeAdaptationLayer, self).__init__()

        self.norm_first = norm_first
        self.zero_shot = zero_shot

        # The appearance attention branch only exists in the few-shot
        # setting; zero-shot has no exemplar embeddings to attend to.
        if not self.zero_shot:
            self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm3 = nn.LayerNorm(emb_dim, layer_norm_eps)
        if not self.zero_shot:
            self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)
        self.dropout3 = nn.Dropout(dropout)

        if not self.zero_shot:
            self.self_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)
        self.enc_dec_attn = nn.MultiheadAttention(emb_dim, num_heads, dropout)

        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        """Add a positional embedding to x, treating None as 'no embedding'."""
        return x if emb is None else x + emb

    def forward(
        self, tgt, appearance, memory, pos_emb, query_pos_emb, tgt_mask, memory_mask,
        tgt_key_padding_mask, memory_key_padding_mask
    ):
        if self.norm_first:
            # Pre-norm: normalize before each sub-block, residual outside.
            if not self.zero_shot:
                tgt_norm = self.norm1(tgt)
                tgt = tgt + self.dropout1(self.self_attn(
                    query=self.with_emb(tgt_norm, query_pos_emb),
                    key=self.with_emb(appearance, query_pos_emb),
                    value=appearance,
                    attn_mask=tgt_mask,
                    key_padding_mask=tgt_key_padding_mask
                )[0])

            tgt_norm = self.norm2(tgt)
            tgt = tgt + self.dropout2(self.enc_dec_attn(
                query=self.with_emb(tgt_norm, query_pos_emb),
                key=memory + pos_emb,
                value=memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask
            )[0])
            tgt_norm = self.norm3(tgt)
            tgt = tgt + self.dropout3(self.mlp(tgt_norm))

        else:
            # Post-norm: residual first, then normalize.
            if not self.zero_shot:
                # BUGFIX: the original called `self.with_emb(appearance)`
                # with a single argument, but with_emb requires (x, emb),
                # so this branch raised a TypeError at runtime. Pass the
                # query positional embedding, mirroring the pre-norm branch.
                tgt = self.norm1(tgt + self.dropout1(self.self_attn(
                    query=self.with_emb(tgt, query_pos_emb),
                    key=self.with_emb(appearance, query_pos_emb),
                    value=appearance,
                    attn_mask=tgt_mask,
                    key_padding_mask=tgt_key_padding_mask
                )[0]))

            tgt = self.norm2(tgt + self.dropout2(self.enc_dec_attn(
                query=self.with_emb(tgt, query_pos_emb),
                key=memory + pos_emb,
                value=memory,
                attn_mask=memory_mask,
                key_padding_mask=memory_key_padding_mask
            )[0]))

            tgt = self.norm3(tgt + self.dropout3(self.mlp(tgt)))

        return tgt
|
models/enc_model/positional_encoding.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import nn
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class PositionalEncodingsFixed(nn.Module):
    """Fixed (non-learned) 2D sinusoidal positional encodings.

    forward() yields a (bs, emb_dim, h, w) tensor: the first emb_dim//2
    channels encode the vertical position, the rest the horizontal one,
    using the standard interleaved sin/cos scheme.
    """

    def __init__(self, emb_dim, temperature=10000):
        super().__init__()
        self.emb_dim = emb_dim
        self.temperature = temperature

    def _1d_pos_enc(self, mask, dim):
        # Frequency divisors: temperature ** (2*floor(i/2) / emb_dim).
        freq = torch.arange(self.emb_dim // 2).float().to(mask.device)
        freq = self.temperature ** (2 * freq.div(2, rounding_mode='floor') / self.emb_dim)

        # Positions are 1-based cumulative counts of non-masked entries.
        scaled = (~mask).cumsum(dim).float().unsqueeze(-1) / freq
        # Interleave sin on even indices and cos on odd indices.
        return torch.stack(
            [scaled[..., 0::2].sin(), scaled[..., 1::2].cos()], dim=-1
        ).flatten(-2)

    def forward(self, bs, h, w, device):
        # No padding support here: every position counts, so the mask is all-False.
        mask = torch.zeros(bs, h, w, dtype=torch.bool, requires_grad=False, device=device)
        enc_x = self._1d_pos_enc(mask, dim=2)
        enc_y = self._1d_pos_enc(mask, dim=1)
        # (bs, h, w, emb_dim) -> (bs, emb_dim, h, w)
        return torch.cat([enc_y, enc_x], dim=3).permute(0, 3, 1, 2)
|
models/enc_model/regression_head.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from torch import nn
|
| 2 |
+
import torch
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class UpsamplingLayer(nn.Module):
    """3x3 conv + (Leaky)ReLU followed by 2x bilinear upsampling."""

    def __init__(self, in_channels, out_channels, leaky=True):
        super().__init__()
        activation = nn.LeakyReLU() if leaky else nn.ReLU()
        self.layer = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=3, padding=1),
            activation,
            nn.UpsamplingBilinear2d(scale_factor=2),
        )

    def forward(self, x):
        return self.layer(x)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
class DensityMapRegressor(nn.Module):
    """Decodes backbone features into a single-channel density map.

    Upsamples by the given reduction factor (8 or 16) with stacked
    conv + bilinear-upsample blocks, ending in a 1x1 conv + LeakyReLU.

    Args:
        in_channels: number of channels of the input feature map.
        reduction: total downsampling factor of the features relative to
            the image; must be 8 or 16.

    Raises:
        ValueError: if ``reduction`` is not 8 or 16. (Previously an
            unsupported value left ``self.regressor`` undefined and only
            failed later with an AttributeError in forward().)
    """

    def __init__(self, in_channels, reduction):

        super(DensityMapRegressor, self).__init__()

        if reduction == 8:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                nn.Conv2d(32, 1, kernel_size=1),
                nn.LeakyReLU()
            )
        elif reduction == 16:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                UpsamplingLayer(32, 16),
                nn.Conv2d(16, 1, kernel_size=1),
                nn.LeakyReLU()
            )
        else:
            raise ValueError(f"reduction must be 8 or 16, got {reduction!r}")

        self.reset_parameters()

    def forward(self, x):
        return self.regressor(x)

    def reset_parameters(self):
        """Re-initialize all convs: N(0, 0.01) weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class DensityMapRegressor_(nn.Module):
    """Near-duplicate of DensityMapRegressor (kept for compatibility).

    BUGFIX: the original called ``super(DensityMapRegressor, self)`` with
    the wrong class name, which raises ``TypeError: super(type, obj)``
    on every instantiation since this class is not a subclass of
    DensityMapRegressor. The super call now names this class.

    Raises:
        ValueError: if ``reduction`` is not 8 or 16.
    """

    def __init__(self, in_channels, reduction):

        super(DensityMapRegressor_, self).__init__()

        if reduction == 8:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                nn.Conv2d(32, 1, kernel_size=1),
                nn.LeakyReLU()
            )
        elif reduction == 16:
            self.regressor = nn.Sequential(
                UpsamplingLayer(in_channels, 128),
                UpsamplingLayer(128, 64),
                UpsamplingLayer(64, 32),
                UpsamplingLayer(32, 16),
                nn.Conv2d(16, 1, kernel_size=1),
                nn.LeakyReLU()
            )
        else:
            raise ValueError(f"reduction must be 8 or 16, got {reduction!r}")

        self.reset_parameters()

    def forward(self, x):
        return self.regressor(x)

    def reset_parameters(self):
        """Re-initialize all convs: N(0, 0.01) weights, zero biases."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.normal_(module.weight, std=0.01)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
|
models/enc_model/transformer.py
ADDED
|
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .mlp import MLP
|
| 2 |
+
|
| 3 |
+
from torch import nn
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class TransformerEncoder(nn.Module):
    """Plain stack of TransformerEncoderLayer blocks with a final norm."""

    def __init__(
        self,
        num_layers: int,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
        norm: bool,
    ):
        super().__init__()

        self.layers = nn.ModuleList(
            TransformerEncoderLayer(
                emb_dim, num_heads, dropout, layer_norm_eps,
                mlp_factor, norm_first, activation
            )
            for _ in range(num_layers)
        )

        # Final normalization (identity when disabled).
        self.norm = nn.LayerNorm(emb_dim, layer_norm_eps) if norm else nn.Identity()

    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        out = src
        for layer in self.layers:
            out = layer(out, pos_emb, src_mask, src_key_padding_mask)
        return self.norm(out)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class TransformerEncoderLayer(nn.Module):
    """Single encoder layer: self-attention + MLP, pre- or post-norm."""

    def __init__(
        self,
        emb_dim: int,
        num_heads: int,
        dropout: float,
        layer_norm_eps: float,
        mlp_factor: int,
        norm_first: bool,
        activation: nn.Module,
    ):
        super().__init__()

        self.norm_first = norm_first

        self.norm1 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.norm2 = nn.LayerNorm(emb_dim, layer_norm_eps)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.self_attn = nn.MultiheadAttention(
            emb_dim, num_heads, dropout
        )
        self.mlp = MLP(emb_dim, mlp_factor * emb_dim, dropout, activation)

    def with_emb(self, x, emb):
        """Add a positional embedding to x, treating None as 'no embedding'."""
        if emb is None:
            return x
        return x + emb

    def forward(self, src, pos_emb, src_mask, src_key_padding_mask):
        if self.norm_first:
            # Pre-norm: normalize before each sub-block, residual outside.
            normed = self.norm1(src)
            qk = normed + pos_emb
            attn_out, _ = self.self_attn(
                query=qk,
                key=qk,
                value=normed,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
            )
            src = src + self.dropout1(attn_out)
            src = src + self.dropout2(self.mlp(self.norm2(src)))
        else:
            # Post-norm: residual first, then normalize.
            qk = src + pos_emb
            attn_out, _ = self.self_attn(
                query=qk,
                key=qk,
                value=src,
                attn_mask=src_mask,
                key_padding_mask=src_key_padding_mask,
            )
            src = self.norm1(src + self.dropout1(attn_out))
            src = self.norm2(src + self.dropout2(self.mlp(src)))

        return src
|
models/enc_model/unet_parts.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" Parts of the U-Net model """
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn as nn
|
| 5 |
+
import torch.nn.functional as F
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class DoubleConv(nn.Module):
    """(convolution => [BN] => ReLU) * 2"""

    def __init__(self, in_channels, out_channels, mid_channels=None):
        super().__init__()
        # Default the intermediate width to the output width.
        mid_channels = mid_channels or out_channels
        self.double_conv = nn.Sequential(
            nn.Conv2d(in_channels, mid_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(mid_channels),
            nn.ReLU(inplace=True),
            nn.Conv2d(mid_channels, out_channels, kernel_size=3, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.double_conv(x)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class Down(nn.Module):
    """Downscaling with maxpool then double conv"""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        # Halve the spatial resolution, then refine with two convolutions.
        self.maxpool_conv = nn.Sequential(
            nn.MaxPool2d(2),
            DoubleConv(in_channels, out_channels),
        )

    def forward(self, x):
        return self.maxpool_conv(x)
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
class Up(nn.Module):
    """Upscaling then double conv"""

    def __init__(self, in_channels, out_channels, bilinear=True):
        super().__init__()

        if bilinear:
            # Bilinear upsampling has no parameters, so the channel
            # reduction happens inside the following convolution instead.
            self.up = nn.Upsample(scale_factor=2, mode='bilinear', align_corners=True)
            self.conv = DoubleConv(in_channels, out_channels, in_channels // 2)
        else:
            self.up = nn.ConvTranspose2d(in_channels, in_channels // 2, kernel_size=2, stride=2)
            self.conv = DoubleConv(in_channels, out_channels)

    def forward(self, x1, x2):
        x1 = self.up(x1)
        # Tensors are NCHW; pad x1 so its spatial size matches the skip
        # connection x2 before concatenating.
        diff_y = x2.size()[2] - x1.size()[2]
        diff_x = x2.size()[3] - x1.size()[3]
        x1 = F.pad(x1, [diff_x // 2, diff_x - diff_x // 2,
                        diff_y // 2, diff_y - diff_y // 2])
        # See the U-Net reference implementations for padding caveats.
        merged = torch.cat([x2, x1], dim=1)
        return self.conv(merged)
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
class OutConv(nn.Module):
    """Final 1x1 projection to the requested number of output channels."""

    def __init__(self, in_channels, out_channels):
        super().__init__()
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=1)

    def forward(self, x):
        return self.conv(x)
|
models/model.py
ADDED
|
@@ -0,0 +1,653 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import os
|
| 5 |
+
import clip
|
| 6 |
+
import sys
|
| 7 |
+
import numpy as np
|
| 8 |
+
from models.seg_post_model.cellpose.models import CellposeModel
|
| 9 |
+
|
| 10 |
+
from torchvision.ops import roi_align
|
| 11 |
+
def crop_roi_feat(feat, boxes, img_size=512):
    """Crop per-box sub-regions out of a feature map.

    Args:
        feat: 1 x c x h x w feature tensor.
        boxes: m x 4 float tensor of [y_tl, x_tl, y_br, x_br] coordinates
            in image space. The input tensor is not modified.
        img_size: side length of the original image the features were
            computed from; the feature stride is img_size / h. Defaults
            to 512, matching the previously hard-coded value.

    Returns:
        List of m tensors, each 1 x c x h_box x w_box.
    """
    _, _, h, w = feat.shape
    out_stride = img_size / h
    # Fresh tensor; the caller's boxes stay untouched.
    boxes_scaled = boxes / out_stride
    # Round outward so the crop fully covers the fractional box.
    boxes_scaled[:, :2] = torch.floor(boxes_scaled[:, :2])  # y_tl, x_tl: floor
    boxes_scaled[:, 2:] = torch.ceil(boxes_scaled[:, 2:])   # y_br, x_br: ceil
    # Clip to the feature-map bounds.
    boxes_scaled[:, :2] = torch.clamp_min(boxes_scaled[:, :2], 0)
    boxes_scaled[:, 2] = torch.clamp_max(boxes_scaled[:, 2], h)
    boxes_scaled[:, 3] = torch.clamp_max(boxes_scaled[:, 3], w)
    feat_boxes = []
    for idx_box in range(boxes.shape[0]):
        y_tl, x_tl, y_br, x_br = (int(v) for v in boxes_scaled[idx_box])
        # Bottom/right edges are inclusive, hence the +1.
        feat_boxes.append(feat[:, :, y_tl:(y_br + 1), x_tl:(x_br + 1)])
    return feat_boxes
|
| 31 |
+
|
| 32 |
+
class Counting_with_SD_features(nn.Module):
    """Counting wrapper holding only the ROI-based feature adapter."""

    def __init__(self, scale_factor):
        super(Counting_with_SD_features, self).__init__()
        # `scale_factor` is accepted for signature compatibility with the
        # sibling variants but is not stored here.
        self.adapter = adapter_roi()
        # NOTE: unlike the *_loca variant, this wrapper has no regressor.
|
| 37 |
+
|
| 38 |
+
class Counting_with_SD_features_loca(nn.Module):
    """LOCA-style counting head: ROI adapter plus density regressor."""

    def __init__(self, scale_factor):
        super(Counting_with_SD_features_loca, self).__init__()
        # `scale_factor` is accepted for signature compatibility but unused.
        self.adapter = adapter_roi_loca()
        self.regressor = regressor_with_SD_features()
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class Counting_with_SD_features_dino_vit_c3(nn.Module):
    """Counting head variant: ROI adapter plus ViT-based segmentation
    regressor (3-channel input variant)."""

    def __init__(self, scale_factor, vit=None):
        super(Counting_with_SD_features_dino_vit_c3, self).__init__()
        # `scale_factor` and `vit` are accepted for signature
        # compatibility with sibling variants but are not stored here.
        self.adapter = adapter_roi_loca()
        self.regressor = regressor_with_SD_features_seg_vit_c3()
|
| 50 |
+
|
| 51 |
+
class Counting_with_SD_features_track(nn.Module):
    """Tracking head variant: ROI adapter plus tracking regressor."""

    def __init__(self, scale_factor, vit=None):
        super(Counting_with_SD_features_track, self).__init__()
        # `scale_factor` and `vit` are accepted for signature
        # compatibility with sibling variants but are not stored here.
        self.adapter = adapter_roi_loca()
        self.regressor = regressor_with_SD_features_tra()
|
| 56 |
+
|
| 57 |
+
|
| 58 |
+
class adapter_roi(nn.Module):
    """Pools exemplar-box features from a feature map and projects them
    to a single 768-d embedding through a bottleneck MLP."""

    def __init__(self, pool_size=[3, 3]):
        super(adapter_roi, self).__init__()
        self.pool_size = pool_size
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(256 * 3 * 3, 768)
        # Bottleneck projection: 768 -> 192 -> 768.
        self.fc1 = nn.Sequential(
            nn.ReLU(),
            nn.Linear(768, 768 // 4, bias=False),
            nn.ReLU()
        )
        self.fc2 = nn.Sequential(
            nn.Linear(768 // 4, 768, bias=False),
        )
        self.initialize_weights()

    def forward(self, x, boxes):
        bs = x.shape[0]
        boxes_per_image = boxes.shape[1]
        # Prepend the per-box batch index expected by torchvision's roi_align.
        batch_idx = torch.arange(bs, requires_grad=False).to(boxes.device)
        batch_idx = batch_idx.repeat_interleave(boxes_per_image).reshape(-1, 1)
        indexed_boxes = torch.cat([batch_idx, boxes.flatten(0, 1)], dim=1)
        pooled = roi_align(
            x,
            boxes=indexed_boxes, output_size=3,
            spatial_scale=1.0 / 8, aligned=True
        )
        # Average the pooled features over every box (result keeps batch 1).
        pooled = torch.mean(pooled, dim=0, keepdim=True)
        out = self.conv1(pooled)
        out = out.view(out.size(0), -1)
        out = self.fc(out)
        out = self.fc1(out)
        return self.fc2(out)

    def initialize_weights(self):
        """Xavier-normal weights and zero biases for all conv/linear layers."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
class adapter_roi_loca(nn.Module):
    """ROI-pooling adapter for the LOCA-style counting pipeline.

    Pools 3x3 features for each exemplar box with roi_align (after
    resizing the feature map to 512x512) and projects the averaged ROI
    features to a 768-d embedding.
    """

    def __init__(self, pool_size=[3, 3]):
        super(adapter_roi_loca, self).__init__()
        # pool_size is stored but self.pool (MaxPool2d) is the only pooling
        # module defined here; neither is used in forward().
        self.pool_size = pool_size
        self.conv1 = nn.Conv2d(256, 256, kernel_size=3, padding=1)
        self.pool = nn.MaxPool2d(2)
        self.fc = nn.Linear(256 * 3 * 3, 768)
        self.initialize_weights()
    def forward(self, x, boxes):
        """Compute one 768-d embedding per batch element.

        Args:
            x: (bs, 256, h, w) feature map; resized to 512x512 if needed.
            boxes: (bs, num_boxes, 4) exemplar boxes — assumed to be in
                (x_tl, y_tl, x_br, y_br) order in 512-image coordinates,
                as expected by roi_align — TODO confirm against caller.
        """
        num_of_boxes = boxes.shape[1]
        rois = []
        bs, _, h, w = x.shape
        # spatial_scale below assumes the 512x512 input resolution.
        if h != 512 or w != 512:
            x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)
        if bs == 1:
            # Single image: prepend the batch index column required when
            # roi_align receives boxes as one (K, 5) tensor.
            boxes = torch.cat([
                torch.arange(
                    bs, requires_grad=False
                ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
                boxes.flatten(0, 1),
            ], dim=1)
            rois = roi_align(
                x,
                boxes=boxes, output_size=3,
                spatial_scale=1.0 / 8, aligned=True
            )
            # Average across all exemplar boxes -> (1, 256, 3, 3).
            rois = torch.mean(rois, dim=0, keepdim=True)
        else:
            # Batched path: split the flattened boxes back into a per-image
            # list (the list form of roi_align's `boxes` argument).
            # NOTE(review): the torch.cat around a single tensor is a no-op;
            # correctness relies on the split order matching the batch order.
            boxes = torch.cat([
                boxes.flatten(0, 1),
            ], dim=1).split(num_of_boxes, dim=0)
            rois = roi_align(
                x,
                boxes=boxes, output_size=3,
                spatial_scale=1.0 / 8, aligned=True
            )
            # Regroup per image, then average over each image's boxes
            # -> (bs, 256, 3, 3).
            rois = rois.split(num_of_boxes, dim=0)
            rois = torch.stack(rois, dim=0)
            rois = torch.mean(rois, dim=1, keepdim=False)
        x = self.conv1(rois)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def forward_boxes(self, x, boxes):
        """Like forward(), but returns one embedding per box (no averaging).

        Only implemented for batch size 1.
        """
        num_of_boxes = boxes.shape[1]
        rois = []
        bs, _, h, w = x.shape
        if h != 512 or w != 512:
            x = F.interpolate(x, size=(512, 512), mode='bilinear', align_corners=False)
        if bs == 1:
            boxes = torch.cat([
                torch.arange(
                    bs, requires_grad=False
                ).to(boxes.device).repeat_interleave(num_of_boxes).reshape(-1, 1),
                boxes.flatten(0, 1),
            ], dim=1)
            rois = roi_align(
                x,
                boxes=boxes, output_size=3,
                spatial_scale=1.0 / 8, aligned=True
            )
            # rois = torch.mean(rois, dim=0, keepdim=True)
        else:
            raise NotImplementedError
        x = self.conv1(rois)
        x = x.view(x.size(0), -1)
        x = self.fc(x)
        return x

    def initialize_weights(self):
        # Xavier-normal weights and zero biases for all conv/linear layers.
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
| 188 |
+
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class regressor1(nn.Module):
    """Small density regressor: two conv + LeakyReLU + 2x upsample stages
    (4x total upscale) followed by a single-channel conv with ReLU."""

    def __init__(self):
        super(regressor1, self).__init__()
        self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
        self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
        self.leaky_relu = nn.LeakyReLU()
        self.relu = nn.ReLU()
        self.initialize_weights()

    def forward(self, x):
        out = self.upsampler(self.leaky_relu(self.conv1(x)))
        out = self.upsampler(self.leaky_relu(self.conv2(out)))
        # Final ReLU keeps predicted densities non-negative.
        return self.relu(self.conv3(out))

    def initialize_weights(self):
        """Xavier-normal weights and zero biases for all conv/linear layers."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
| 223 |
+
|
| 224 |
+
|
| 225 |
+
# NOTE(review): this class re-declares `regressor1` and shadows the
# identical definition above. Unlike that one, this __init__ never calls
# initialize_weights(), so the xavier initialization below is dead code —
# confirm which definition is intended before removing either.
class regressor1(nn.Module):
    """4x upsampling density regressor: (conv -> LeakyReLU -> 2x bilinear
    upsample) twice, then a single-channel conv with ReLU."""
    def __init__(self):
        super(regressor1, self).__init__()
        self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
        self.upsampler = nn.UpsamplingBilinear2d(scale_factor=2)
        self.leaky_relu = nn.LeakyReLU()
        self.relu = nn.ReLU()

    def forward(self, x):
        # conv -> LeakyReLU -> 2x upsample, twice.
        x_ = self.conv1(x)
        x_ = self.leaky_relu(x_)
        x_ = self.upsampler(x_)
        x_ = self.conv2(x_)
        x_ = self.leaky_relu(x_)
        x_ = self.upsampler(x_)
        # Final ReLU keeps the predicted density non-negative.
        x_ = self.conv3(x_)
        x_ = self.relu(x_)
        out = x_
        return out
    def initialize_weights(self):
        # Defined but never invoked by this class (see NOTE above).
        for m in self.modules():
            if isinstance(m, nn.Conv2d) or isinstance(m, nn.Linear):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
| 252 |
+
|
| 253 |
+
|
| 254 |
+
class regressor_with_SD_features(nn.Module):
    """Density-map regressor fusing UNet features with attention maps.

    Consumes a stack of attention maps and the last UNet feature map
    (both at 64x64), fuses them, and decodes to a single-channel map
    upsampled 8x (to 512x512), scaled down by a factor of 100.
    """

    def __init__(self):
        super(regressor_with_SD_features, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(324, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64)),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.ReLU(),
        )
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.initialize_weights()

    def forward(self, attn_stack, feature_list):
        # Normalize the attention maps, gate the last UNet feature map by
        # their per-pixel mean, then concatenate along channels.
        normed_attn = self.norm(attn_stack)
        fused = feature_list[-1] * torch.mean(normed_attn, dim=1, keepdim=True)
        fused = torch.cat([fused, normed_attn], dim=1)  # -> (B, 324, 64, 64)
        out = self.layer1(fused)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        # Scale down to keep the predicted densities in a small range.
        return self.conv(out) / 100

    def initialize_weights(self):
        """Xavier-normal weights and zero biases for all conv/linear layers."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
| 303 |
+
|
| 304 |
+
class regressor_with_SD_features_seg(nn.Module):
    """Segmentation variant of the SD-feature regressor.

    Same fusion and decoder stack as regressor_with_SD_features, but the
    head emits two channels (logits) without a final activation or the
    1/100 scaling.
    """

    def __init__(self):
        super(regressor_with_SD_features_seg, self).__init__()
        self.layer1 = nn.Sequential(
            nn.Conv2d(324, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64)),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
        )
        # Two-channel head; no final non-linearity (raw logits).
        self.conv = nn.Sequential(
            nn.Conv2d(32, 2, kernel_size=1),
        )
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.initialize_weights()

    def forward(self, attn_stack, feature_list):
        # Normalize the attention maps, gate the last UNet feature map by
        # their per-pixel mean, then concatenate along channels.
        normed_attn = self.norm(attn_stack)
        fused = feature_list[-1] * torch.mean(normed_attn, dim=1, keepdim=True)
        fused = torch.cat([fused, normed_attn], dim=1)  # -> (B, 324, 64, 64)
        out = self.layer1(fused)
        out = self.layer2(out)
        out = self.layer3(out)
        out = self.layer4(out)
        return self.conv(out)

    def initialize_weights(self):
        """Xavier-normal weights and zero biases for all conv/linear layers."""
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
|
| 353 |
+
|
| 354 |
+
|
| 355 |
+
from models.enc_model.unet_parts import *
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
class regressor_with_SD_features_seg_vit_c3(nn.Module):
    """Segmentation head that feeds fused SD-attention features to Cellpose.

    Attention channels 1 and 3 are normalized, concatenated with the
    channel-mean of the deepest U-Net feature, upsampled to 512x512,
    projected to 3 channels, and handed to ``CellposeModel.eval`` as an
    auxiliary feature map. The result is an instance-label mask tensor.
    """

    def __init__(self, n_channels=3, n_classes=2, bilinear=False):
        super(regressor_with_SD_features_seg_vit_c3, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        # Projects the fused stack down to the 3 channels the ViT expects.
        self.inc_0 = nn.Conv2d(n_channels, 3, kernel_size=3, padding=1)
        self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
        self.vit = self.vit_model.net

    def forward(self, img, attn_stack, feature_list):
        """Return instance masks as a [1, H, W] tensor on the feature device."""
        selected = self.norm(attn_stack[:, [1, 3], ...])
        deep_mean = torch.mean(feature_list[-1], dim=1, keepdim=True)
        fused = torch.cat([deep_mean, selected], dim=1)
        if fused.shape[-1] != 512:
            fused = F.interpolate(fused, size=(512, 512), mode="bilinear")
        fused = self.inc_0(fused)
        # Cellpose eval operates on numpy arrays; round-trip through CPU.
        masks = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=fused.squeeze().cpu().numpy())[0]
        if masks.dtype == np.uint16:
            # torch has no uint16 dtype; reinterpret as int16.
            masks = masks.astype(np.int16)
        return torch.from_numpy(masks).unsqueeze(0).to(fused.device)

    def initialize_weights(self):
        """Xavier-init Conv2d/Linear weights and zero their biases."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
|
| 395 |
+
|
| 396 |
+
class regressor_with_SD_features_tra(nn.Module):
    """Tracking head over SD-attention maps from consecutive frames.

    ``forward`` pairs each instance's previous/next 64x64 attention map and
    projects the pair to a 320-d embedding. ``forward_seg`` exposes the
    Cellpose-based per-frame segmentation path.
    """

    def __init__(self, n_channels=2, n_classes=2, bilinear=False):
        super(regressor_with_SD_features_tra, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))

        # segmentation branch
        self.inc_0 = nn.Conv2d(3, 3, kernel_size=3, padding=1)
        self.vit_model = CellposeModel(gpu=True, nchan=3, pretrained_model="", use_bfloat16=False)
        self.vit = self.vit_model.net

        # tracking branch: collapse the (prev, next) pair, then embed
        self.inc_1 = nn.Conv2d(n_channels, 1, kernel_size=3, padding=1)
        self.mlp = nn.Linear(64 * 64, 320)

    def forward_seg(self, img, attn_stack, feature_list, mask, training=False):
        """Segment one frame via Cellpose; returns (masks, 0., feat)."""
        selected = self.norm(attn_stack[:, [1, 3], ...])
        deep_mean = torch.mean(feature_list[-1], dim=1, keepdim=True)
        fused = torch.cat([deep_mean, selected], dim=1)
        if fused.shape[-1] != 512:
            fused = F.interpolate(fused, size=(512, 512), mode="bilinear")
        fused = self.inc_0(fused)
        feat = fused
        # Cellpose eval operates on numpy arrays; round-trip through CPU.
        masks = self.vit_model.eval(img.squeeze().cpu().numpy(), feat=fused.squeeze().cpu().numpy())[0]
        if masks.dtype == np.uint16:
            masks = masks.astype(np.int16)  # torch has no uint16 dtype
        masks = torch.from_numpy(masks).unsqueeze(0).to(fused.device)
        return masks, 0., feat

    def forward(self, attn_prev, feature_list_prev, attn_after, feature_list_after):
        """Embed each instance's (prev, next) attention pair as a 320-d vector."""
        assert attn_prev.shape == attn_after.shape, "attn_prev and attn_after must have the same shape"
        n_instances = attn_prev.shape[0]
        paired = torch.cat([self.norm(attn_prev), self.norm(attn_after)], dim=1)
        squeezed = self.inc_1(paired)            # [n_instances, 1, 64, 64]
        flat = squeezed.view(1, n_instances, -1)  # flatten spatial dims
        return self.mlp(flat)                     # [1, n_instances, 320]

    def initialize_weights(self):
        """Xavier-init Conv2d/Linear weights and zero their biases."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
|
| 453 |
+
|
| 454 |
+
|
| 455 |
+
|
| 456 |
+
class regressor_with_SD_features_inst_seg_unet(nn.Module):
    """Instance-segmentation head: fuses SD attention with raw pixels and
    runs a standard 4-level U-Net over the result.

    The fused attention stack is upsampled to 512x512, concatenated with the
    input image, reduced to 3 channels, then encoded/decoded with skip
    connections down to ``n_classes`` output channels.
    """

    def __init__(self, n_channels=8, n_classes=3, bilinear=False):
        super(regressor_with_SD_features_inst_seg_unet, self).__init__()
        self.n_channels = n_channels
        self.n_classes = n_classes
        self.bilinear = bilinear
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.inc_0 = (DoubleConv(n_channels, 3))
        self.inc = (DoubleConv(3, 64))
        self.down1 = (Down(64, 128))
        self.down2 = (Down(128, 256))
        self.down3 = (Down(256, 512))
        factor = 2 if bilinear else 1
        self.down4 = (Down(512, 1024 // factor))
        self.up1 = (Up(1024, 512 // factor, bilinear))
        self.up2 = (Up(512, 256 // factor, bilinear))
        self.up3 = (Up(256, 128 // factor, bilinear))
        self.up4 = (Up(128, 64, bilinear))
        self.outc = (OutConv(64, n_classes))

    def forward(self, img, attn_stack, feature_list):
        """Return an [N, n_classes, 512, 512] logit map."""
        normed = self.norm(attn_stack)
        deep_mean = torch.mean(feature_list[-1], dim=1, keepdim=True)
        # Gate the pooled U-Net feature by the pooled attention.
        deep_mean = deep_mean * torch.mean(normed, dim=1, keepdim=True)
        fused = torch.cat([deep_mean, normed], dim=1)
        if fused.shape[-1] != 512:
            fused = F.interpolate(fused, size=(512, 512), mode="bilinear")
        fused = torch.cat([img, fused], dim=1)  # prepend raw image channels
        # Standard U-Net encoder/decoder with skip connections.
        stem = self.inc_0(fused)
        e1 = self.inc(stem)
        e2 = self.down1(e1)
        e3 = self.down2(e2)
        e4 = self.down3(e3)
        e5 = self.down4(e4)
        d = self.up1(e5, e4)
        d = self.up2(d, e3)
        d = self.up3(d, e2)
        d = self.up4(d, e1)
        return self.outc(d)

    def initialize_weights(self):
        """Xavier-init Conv2d/Linear weights and zero their biases."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
|
| 505 |
+
|
| 506 |
+
|
| 507 |
+
class regressor_with_SD_features_self(nn.Module):
    """Projects a 4096-channel self-attention volume down to 256 channels.

    ``forward`` only exercises ``self.layer``; layer2/layer3/layer4/conv and
    ``self.norm`` are constructed but unused by the current forward path.
    """

    def __init__(self):
        super(regressor_with_SD_features_self, self).__init__()
        # Attribute order is preserved for checkpoint / RNG compatibility.
        self.layer = nn.Sequential(
            nn.Conv2d(4096, 1024, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64)),
            nn.Conv2d(1024, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64)),
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.ReLU()
        )
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.initialize_weights()

    def forward(self, self_attn):
        # Incoming map is (H, W, C); the convs want channel-first (C, H, W).
        chw = self_attn.permute(2, 0, 1)
        return self.layer(chw)

    def initialize_weights(self):
        """Xavier-init Conv2d/Linear weights, zero biases (deconvs untouched)."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
|
| 562 |
+
|
| 563 |
+
|
| 564 |
+
class regressor_with_SD_features_latent(nn.Module):
    """Lifts a 4-channel latent map to 256 channels.

    ``forward`` only exercises ``self.layer``; the remaining decoder stages
    (layer2..layer4, conv) and ``self.norm`` are constructed but unused by
    the current forward path. Unlike the ``_self`` variant, no permute is
    applied — input is assumed channel-first already (confirm with callers).
    """

    def __init__(self):
        super(regressor_with_SD_features_latent, self).__init__()
        # Attribute order is preserved for checkpoint / RNG compatibility.
        self.layer = nn.Sequential(
            nn.Conv2d(4, 256, kernel_size=1, stride=1),
            nn.LeakyReLU(),
            nn.LayerNorm((64, 64))
        )
        self.layer2 = nn.Sequential(
            nn.Conv2d(256, 128, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=128, out_channels=128, kernel_size=4, stride=2, padding=1),
        )
        self.layer3 = nn.Sequential(
            nn.Conv2d(128, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.ConvTranspose2d(in_channels=64, out_channels=64, kernel_size=4, stride=2, padding=1),
        )
        self.layer4 = nn.Sequential(
            nn.Conv2d(64, 32, kernel_size=3, padding=1),
            nn.LeakyReLU(),
            nn.ConvTranspose2d(in_channels=32, out_channels=32, kernel_size=4, stride=2, padding=1),
        )
        self.conv = nn.Sequential(
            nn.Conv2d(32, 1, kernel_size=1),
            nn.ReLU()
        )
        self.norm = nn.LayerNorm(normalized_shape=(64, 64))
        self.initialize_weights()

    def forward(self, self_attn):
        """Project the 4-channel input to 256 normalized channels."""
        return self.layer(self_attn)

    def initialize_weights(self):
        """Xavier-init Conv2d/Linear weights, zero biases (deconvs untouched)."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
|
| 619 |
+
|
| 620 |
+
|
| 621 |
+
class regressor_with_deconv(nn.Module):
    """Small conv/deconv stack: upsamples a 4-channel map 4x into one
    non-negative output channel (density-map style regressor)."""

    def __init__(self):
        super(regressor_with_deconv, self).__init__()
        self.conv1 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv2 = nn.Conv2d(4, 4, kernel_size=3, stride=1, padding=1)
        self.conv3 = nn.Conv2d(4, 1, kernel_size=3, stride=1, padding=1)
        self.deconv1 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
        self.deconv2 = nn.ConvTranspose2d(4, 4, kernel_size=4, stride=2, padding=1)
        self.leaky_relu = nn.LeakyReLU()
        self.relu = nn.ReLU()
        self.initialize_weights()

    def forward(self, x):
        """Two (conv -> LeakyReLU -> 2x deconv) stages, then ReLU projection."""
        h = self.deconv1(self.leaky_relu(self.conv1(x)))
        h = self.deconv2(self.leaky_relu(self.conv2(h)))
        return self.relu(self.conv3(h))

    def initialize_weights(self):
        """Xavier-init every conv/deconv/linear weight; zero all biases."""
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d, nn.Linear)):
                nn.init.xavier_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
|
models/seg_post_model/cellpose/__init__.py
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
from .version import version, version_str
|
models/seg_post_model/cellpose/__main__.py
ADDED
|
@@ -0,0 +1,272 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import os, time
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import tqdm
|
| 7 |
+
from cellpose import utils, models, io, train
|
| 8 |
+
from .version import version_str
|
| 9 |
+
from cellpose.cli import get_arg_parser
|
| 10 |
+
|
| 11 |
+
# Import the Qt GUI lazily; remember why it failed so main() can report it.
try:
    from cellpose.gui import gui3d, gui
    GUI_ENABLED = True
except ImportError as err:
    # GUI extras not installed — the CLI still works without them.
    GUI_ERROR = err
    GUI_ENABLED = False
    GUI_IMPORT = True
except Exception as err:
    # Unexpected failure (not a missing dependency): record and re-raise.
    GUI_ENABLED = False
    GUI_ERROR = err
    GUI_IMPORT = False
    raise
|
| 23 |
+
|
| 24 |
+
import logging
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def main():
    """ Run cellpose from command line

    Dispatches on parsed CLI args: prints version, launches the GUI (when no
    image/dir argument is given), or runs evaluation/training via the
    ``_evaluate_cellposemodel_cli`` / ``_train_cellposemodel_cli`` helpers.
    """

    args = get_arg_parser().parse_args()  # this has to be in a separate file for autodoc to work

    if args.version:
        print(version_str)
        return

    ######## if no image arguments are provided, run GUI or add model and exit ########
    if len(args.dir) == 0 and len(args.image_path) == 0:
        if args.add_model:
            io.add_model(args.add_model)
            return
        else:
            if not GUI_ENABLED:
                print("GUI ERROR: %s" % GUI_ERROR)
                if GUI_IMPORT:
                    print(
                        "GUI FAILED: GUI dependencies may not be installed, to install, run"
                    )
                    print(" pip install 'cellpose[gui]'")
            else:
                if args.Zstack:
                    gui3d.run()
                else:
                    gui.run()
            return

    ############################## run cellpose on images ##############################
    if args.verbose:
        from .io import logger_setup
        logger, log_file = logger_setup()
    else:
        print(
            ">>>> !LOGGING OFF BY DEFAULT! To see cellpose progress, set --verbose")
        print("No --verbose => no progress or info printed")
        logger = logging.getLogger(__name__)

    # find images
    if len(args.img_filter) > 0:
        image_filter = args.img_filter
    else:
        image_filter = None

    device, gpu = models.assign_device(use_torch=True, gpu=args.use_gpu,
                                       device=args.gpu_device)

    # Training from scratch is not supported in v4; fall back to 'cpsam'.
    if args.pretrained_model is None or args.pretrained_model == "None" or args.pretrained_model == "False" or args.pretrained_model == "0":
        pretrained_model = "cpsam"
        logger.warning("training from scratch is disabled, using 'cpsam' model")
    else:
        pretrained_model = args.pretrained_model

    # Warn users about old arguments from CP3:
    if args.pretrained_model_ortho:
        logger.warning(
            "the '--pretrained_model_ortho' flag is deprecated in v4.0.1+ and no longer used")
    if args.train_size:
        logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")
    if args.chan or args.chan2:
        logger.warning('--chan and --chan2 are deprecated, all channels are used by default')
    if args.all_channels:
        logger.warning("the '--all_channels' flag is deprecated in v4.0.1+ and no longer used")
    if args.restore_type:
        logger.warning("the '--restore_type' flag is deprecated in v4.0.1+ and no longer used")
    if args.transformer:
        logger.warning("the '--tranformer' flag is deprecated in v4.0.1+ and no longer used")
    if args.invert:
        logger.warning("the '--invert' flag is deprecated in v4.0.1+ and no longer used")
    if args.chan2_restore:
        logger.warning("the '--chan2_restore' flag is deprecated in v4.0.1+ and no longer used")
    if args.diam_mean:
        logger.warning("the '--diam_mean' flag is deprecated in v4.0.1+ and no longer used")
    if args.train_size:
        logger.warning("the '--train_size' flag is deprecated in v4.0.1+ and no longer used")

    # Normalization: explicit percentile pair, or a simple on/off flag.
    if args.norm_percentile is not None:
        value1, value2 = args.norm_percentile
        normalize = {'percentile': (float(value1), float(value2))}
    else:
        normalize = (not args.no_norm)

    if args.save_each:
        if not args.save_every:
            raise ValueError("ERROR: --save_each requires --save_every")

    if len(args.image_path) > 0 and args.train:
        raise ValueError("ERROR: cannot train model with single image input")

    ## Run evaluation on images
    if not args.train:
        _evaluate_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)

    ## Train a model ##
    else:
        _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize)
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def _train_cellposemodel_cli(args, logger, image_filter, device, pretrained_model, normalize):
    """Train a Cellpose segmentation model from CLI args and return it.

    Training data comes either from an ``.npy`` file list (``--file_list``)
    or by scanning ``--dir`` / ``--test_dir`` for image/mask pairs.
    """
    test_dir = None if len(args.test_dir) == 0 else args.test_dir
    images, labels, image_names, train_probs = None, None, None, None
    test_images, test_labels, image_names_test, test_probs = None, None, None, None
    compute_flows = False
    if len(args.file_list) > 0:
        if os.path.exists(args.file_list):
            dat = np.load(args.file_list, allow_pickle=True).item()
            image_names = dat["train_files"]
            image_names_test = dat.get("test_files", None)
            train_probs = dat.get("train_probs", None)
            test_probs = dat.get("test_probs", None)
            compute_flows = dat.get("compute_flows", False)
            load_files = False
        else:
            # NOTE(review): logs critical but falls through with load_files
            # unset — training below would fail; confirm this is intended.
            logger.critical(f"ERROR: {args.file_list} does not exist")
    else:
        output = io.load_train_test_data(args.dir, test_dir, image_filter,
                                         args.mask_filter,
                                         args.look_one_level_down)
        images, labels, image_names, test_images, test_labels, image_names_test = output
        load_files = True

    # initialize model
    model = models.CellposeModel(device=device, pretrained_model=pretrained_model)

    # train segmentation model
    cpmodel_path = train.train_seg(
        model.net, images, labels, train_files=image_names,
        test_data=test_images, test_labels=test_labels,
        test_files=image_names_test, train_probs=train_probs,
        test_probs=test_probs, compute_flows=compute_flows,
        load_files=load_files, normalize=normalize,
        channel_axis=args.channel_axis,
        learning_rate=args.learning_rate, weight_decay=args.weight_decay,
        SGD=args.SGD, n_epochs=args.n_epochs, batch_size=args.train_batch_size,
        min_train_masks=args.min_train_masks,
        nimg_per_epoch=args.nimg_per_epoch,
        nimg_test_per_epoch=args.nimg_test_per_epoch,
        save_path=os.path.realpath(args.dir),
        save_every=args.save_every,
        save_each=args.save_each,
        model_name=args.model_name_out)[0]
    model.pretrained_model = cpmodel_path
    logger.info(">>>> model trained and saved to %s" % cpmodel_path)
    return model
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
def _evaluate_cellposemodel_cli(args, logger, imf, device, pretrained_model, normalize):
    """Run Cellpose inference on every image selected by the CLI args.

    Collects images from ``--dir`` or ``--image_path``, evaluates each with
    ``CellposeModel.eval``, and saves masks/flows according to the various
    ``--save_*`` flags. Returns the constructed model.
    """
    # Check with user if they REALLY mean to run without saving anything
    if not args.train:
        saving_something = args.save_png or args.save_tif or args.save_flows or args.save_txt

    tic = time.time()
    if len(args.dir) > 0:
        image_names = io.get_image_files(
            args.dir, args.mask_filter, imf=imf,
            look_one_level_down=args.look_one_level_down)
    else:
        if os.path.exists(args.image_path):
            image_names = [args.image_path]
        else:
            raise ValueError(f"ERROR: no file found at {args.image_path}")
    nimg = len(image_names)

    if args.savedir:
        if not os.path.exists(args.savedir):
            raise FileExistsError(f"--savedir {args.savedir} does not exist")

    logger.info(
        ">>>> running cellpose on %d images using all channels" % nimg)

    # handle built-in model exceptions
    model = models.CellposeModel(device=device, pretrained_model=pretrained_model,)

    tqdm_out = utils.TqdmToLogger(logger, level=logging.INFO)

    channel_axis = args.channel_axis
    z_axis = args.z_axis

    for image_name in tqdm(image_names, file=tqdm_out):
        if args.do_3D or args.stitch_threshold > 0.:
            logger.info('loading image as 3D zstack')
            image = io.imread_3D(image_name)
            # Default axis conventions for 3D stacks when not given by user.
            if channel_axis is None:
                channel_axis = 3
            if z_axis is None:
                z_axis = 0

        else:
            image = io.imread_2D(image_name)
        out = model.eval(
            image,
            diameter=args.diameter,
            do_3D=args.do_3D,
            augment=args.augment,
            flow_threshold=args.flow_threshold,
            cellprob_threshold=args.cellprob_threshold,
            stitch_threshold=args.stitch_threshold,
            min_size=args.min_size,
            batch_size=args.batch_size,
            bsize=args.bsize,
            resample=not args.no_resample,
            normalize=normalize,
            channel_axis=channel_axis,
            z_axis=z_axis,
            anisotropy=args.anisotropy,
            niter=args.niter,
            flow3D_smooth=args.flow3D_smooth)
        # eval may return extra values; only masks and flows are used here.
        masks, flows = out[:2]

        if args.exclude_on_edges:
            masks = utils.remove_edge_masks(masks)
        if not args.no_npy:
            io.masks_flows_to_seg(image, masks, flows, image_name,
                                  imgs_restore=None,
                                  restore_type=None,
                                  ratio=1.)
        if saving_something:
            suffix = "_cp_masks"
            if args.output_name is not None:
                # (1) If `savedir` is not defined, then must have a non-zero `suffix`
                if args.savedir is None and len(args.output_name) > 0:
                    suffix = args.output_name
                elif args.savedir is not None and not os.path.samefile(args.savedir, args.dir):
                    # (2) If `savedir` is defined, and different from `dir` then
                    # takes the value passed as a param. (which can be empty string)
                    suffix = args.output_name

            io.save_masks(image, masks, flows, image_name,
                          suffix=suffix, png=args.save_png,
                          tif=args.save_tif, save_flows=args.save_flows,
                          save_outlines=args.save_outlines,
                          dir_above=args.dir_above, savedir=args.savedir,
                          save_txt=args.save_txt, in_folders=args.in_folders,
                          save_mpl=args.save_mpl)
        if args.save_rois:
            io.save_rois(masks, image_name)
    logger.info(">>>> completed in %0.3f sec" % (time.time() - tic))

    return model
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
if __name__ == "__main__":
|
| 272 |
+
main()
|
models/seg_post_model/cellpose/cli.py
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright © 2023 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu and Michael Rariden.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import argparse
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def get_arg_parser():
    """Build and return the argparse parser for the cellpose CLI.

    Note: this function has to be in a separate file to allow autodoc to work for CLI.
    The autodoc_mock_imports in conf.py does not work for sphinx-argparse sometimes,
    see https://github.com/ashb/sphinx-argparse/issues/9#issue-1097057823

    Returns:
        argparse.ArgumentParser: parser with all cellpose command line options.
    """

    parser = argparse.ArgumentParser(description="Cellpose Command Line Parameters")

    # misc settings
    parser.add_argument("--version", action="store_true",
                        help="show cellpose version info")
    parser.add_argument(
        "--verbose", action="store_true",
        help="show information about running and settings and save to log")
    parser.add_argument("--Zstack", action="store_true", help="run GUI in 3D mode")

    # settings for CPU vs GPU
    hardware_args = parser.add_argument_group("Hardware Arguments")
    hardware_args.add_argument("--use_gpu", action="store_true",
                               help="use gpu if torch with cuda installed")
    hardware_args.add_argument(
        "--gpu_device", required=False, default="0", type=str,
        help="which gpu device to use, use an integer for torch, or mps for M1")

    # settings for locating and formatting images
    input_img_args = parser.add_argument_group("Input Image Arguments")
    input_img_args.add_argument("--dir", default=[], type=str,
                                help="folder containing data to run or train on.")
    input_img_args.add_argument(
        "--image_path", default=[], type=str, help=
        "if given and --dir not given, run on single image instead of folder (cannot train with this option)"
    )
    input_img_args.add_argument(
        "--look_one_level_down", action="store_true",
        help="run processing on all subdirectories of current folder")
    input_img_args.add_argument("--img_filter", default=[], type=str,
                                help="end string for images to run on")
    input_img_args.add_argument(
        "--channel_axis", default=None, type=int,
        help="axis of image which corresponds to image channels")
    input_img_args.add_argument("--z_axis", default=None, type=int,
                                help="axis of image which corresponds to Z dimension")

    # TODO: remove deprecated in future version
    input_img_args.add_argument(
        "--chan", default=0, type=int, help=
        "Deprecated in v4.0.1+, not used. ")
    input_img_args.add_argument(
        "--chan2", default=0, type=int, help=
        'Deprecated in v4.0.1+, not used. ')
    input_img_args.add_argument("--invert", action="store_true", help=
                                'Deprecated in v4.0.1+, not used. ')
    input_img_args.add_argument(
        "--all_channels", action="store_true", help=
        'Deprecated in v4.0.1+, not used. ')

    # model settings
    model_args = parser.add_argument_group("Model Arguments")
    model_args.add_argument("--pretrained_model", required=False, default="cpsam",
                            type=str,
                            help="model to use for running or starting training")
    model_args.add_argument(
        "--add_model", required=False, default=None, type=str,
        help="model path to copy model to hidden .cellpose folder for using in GUI/CLI")
    model_args.add_argument("--pretrained_model_ortho", required=False, default=None,
                            type=str,
                            help="Deprecated in v4.0.1+, not used. ")

    # TODO: remove deprecated in future version
    model_args.add_argument("--restore_type", required=False, default=None, type=str, help=
                            'Deprecated in v4.0.1+, not used. ')
    model_args.add_argument("--chan2_restore", action="store_true", help=
                            'Deprecated in v4.0.1+, not used. ')
    model_args.add_argument(
        "--transformer", action="store_true", help=
        "use transformer backbone (pretrained_model from Cellpose3 is transformer_cp3)")

    # algorithm settings
    algorithm_args = parser.add_argument_group("Algorithm Arguments")
    algorithm_args.add_argument("--no_norm", action="store_true",
                                help="do not normalize images (normalize=False)")
    algorithm_args.add_argument(
        '--norm_percentile',
        nargs=2,  # Require exactly two values
        type=float,  # FIX: values were parsed as strings; help promises floats
        metavar=('VALUE1', 'VALUE2'),
        help="Provide two float values to set norm_percentile (e.g., --norm_percentile 1 99)"
    )
    algorithm_args.add_argument(
        "--do_3D", action="store_true",
        help="process images as 3D stacks of images (nplanes x nchan x Ly x Lx")
    algorithm_args.add_argument(
        "--diameter", required=False, default=None, type=float, help=
        "use to resize cells to the training diameter (30 pixels)"
    )
    algorithm_args.add_argument(
        "--stitch_threshold", required=False, default=0.0, type=float,
        help="compute masks in 2D then stitch together masks with IoU>0.9 across planes"
    )
    algorithm_args.add_argument(
        "--min_size", required=False, default=15, type=int,
        help="minimum number of pixels per mask, can turn off with -1")
    algorithm_args.add_argument(
        "--flow3D_smooth", required=False, default=0, type=float,
        help="stddev of gaussian for smoothing of dP for dynamics in 3D, default of 0 means no smoothing")
    algorithm_args.add_argument(
        "--flow_threshold", default=0.4, type=float, help=
        "flow error threshold, 0 turns off this optional QC step. Default: %(default)s")
    algorithm_args.add_argument(
        "--cellprob_threshold", default=0, type=float,
        help="cellprob threshold, default is 0, decrease to find more and larger masks")
    algorithm_args.add_argument(
        "--niter", default=0, type=int, help=
        "niter, number of iterations for dynamics for mask creation, default of 0 means it is proportional to diameter, set to a larger number like 2000 for very long ROIs"
    )
    algorithm_args.add_argument("--anisotropy", required=False, default=1.0, type=float,
                                help="anisotropy of volume in 3D")
    algorithm_args.add_argument("--exclude_on_edges", action="store_true",
                                help="discard masks which touch edges of image")
    algorithm_args.add_argument(
        "--augment", action="store_true",
        help="tiles image with overlapping tiles and flips overlapped regions to augment"
    )
    algorithm_args.add_argument("--batch_size", default=8, type=int,
                                help="inference batch size. Default: %(default)s")

    # TODO: remove deprecated in future version
    algorithm_args.add_argument(
        "--no_resample", action="store_true",
        help="disables flows/cellprob resampling to original image size before computing masks. Using this flag will make more masks more jagged with larger diameter settings.")
    algorithm_args.add_argument(
        "--no_interp", action="store_true",
        help="do not interpolate when running dynamics (was default)")

    # output settings
    output_args = parser.add_argument_group("Output Arguments")
    output_args.add_argument(
        "--save_png", action="store_true",
        help="save masks as png")
    output_args.add_argument(
        "--save_tif", action="store_true",
        help="save masks as tif")
    output_args.add_argument(
        "--output_name", default=None, type=str,
        help="suffix for saved masks, default is _cp_masks, can be empty if `savedir` used and different of `dir`")
    output_args.add_argument("--no_npy", action="store_true",
                             help="suppress saving of npy")
    output_args.add_argument(
        "--savedir", default=None, type=str, help=
        "folder to which segmentation results will be saved (defaults to input image directory)"
    )
    output_args.add_argument(
        "--dir_above", action="store_true", help=
        "save output folders adjacent to image folder instead of inside it (off by default)"
    )
    output_args.add_argument("--in_folders", action="store_true",
                             help="flag to save output in folders (off by default)")
    output_args.add_argument(
        "--save_flows", action="store_true", help=
        "whether or not to save RGB images of flows when masks are saved (disabled by default)"
    )
    output_args.add_argument(
        "--save_outlines", action="store_true", help=
        "whether or not to save RGB outline images when masks are saved (disabled by default)"
    )
    output_args.add_argument(
        "--save_rois", action="store_true",
        help="whether or not to save ImageJ compatible ROI archive (disabled by default)"
    )
    output_args.add_argument(
        "--save_txt", action="store_true",
        help="flag to enable txt outlines for ImageJ (disabled by default)")
    output_args.add_argument(
        "--save_mpl", action="store_true",
        help="save a figure of image/mask/flows using matplotlib (disabled by default). "
        "This is slow, especially with large images.")

    # training settings
    training_args = parser.add_argument_group("Training Arguments")
    training_args.add_argument("--train", action="store_true",
                               help="train network using images in dir")
    training_args.add_argument("--test_dir", default=[], type=str,
                               help="folder containing test data (optional)")
    training_args.add_argument(
        "--file_list", default=[], type=str, help=
        "path to list of files for training and testing and probabilities for each image (optional)"
    )
    training_args.add_argument(
        "--mask_filter", default="_masks", type=str, help=
        "end string for masks to run on. use '_seg.npy' for manual annotations from the GUI. Default: %(default)s"
    )
    training_args.add_argument("--learning_rate", default=1e-5, type=float,
                               help="learning rate. Default: %(default)s")
    training_args.add_argument("--weight_decay", default=0.1, type=float,
                               help="weight decay. Default: %(default)s")
    training_args.add_argument("--n_epochs", default=100, type=int,
                               help="number of epochs. Default: %(default)s")
    training_args.add_argument("--train_batch_size", default=1, type=int,
                               help="training batch size. Default: %(default)s")
    training_args.add_argument("--bsize", default=256, type=int,
                               help="block size for tiles. Default: %(default)s")
    training_args.add_argument(
        "--nimg_per_epoch", default=None, type=int,
        help="number of train images per epoch. Default is to use all train images.")
    training_args.add_argument(
        "--nimg_test_per_epoch", default=None, type=int,
        help="number of test images per epoch. Default is to use all test images.")
    training_args.add_argument(
        "--min_train_masks", default=5, type=int, help=
        "minimum number of masks a training image must have to be used. Default: %(default)s"
    )
    training_args.add_argument("--SGD", default=0, type=int,
                               help="Deprecated in v4.0.1+, not used - AdamW used instead. ")
    training_args.add_argument(
        "--save_every", default=100, type=int,
        help="number of epochs to skip between saves. Default: %(default)s")
    training_args.add_argument(
        "--save_each", action="store_true",
        help="whether or not to save each epoch. Must also use --save_every. (default: False)")
    training_args.add_argument(
        "--model_name_out", default=None, type=str,
        help="Name of model to save as, defaults to name describing model architecture. "
        "Model is saved in the folder specified by --dir in models subfolder.")

    # TODO: remove deprecated in future version
    training_args.add_argument(
        "--diam_mean", default=30., type=float, help=
        'Deprecated in v4.0.1+, not used. ')
    training_args.add_argument("--train_size", action="store_true", help=
                               'Deprecated in v4.0.1+, not used. ')

    return parser
|
models/seg_post_model/cellpose/core.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import logging
|
| 5 |
+
import numpy as np
|
| 6 |
+
from tqdm import trange
|
| 7 |
+
from . import transforms, utils
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
|
| 11 |
+
TORCH_ENABLED = True
|
| 12 |
+
|
| 13 |
+
core_logger = logging.getLogger(__name__)
|
| 14 |
+
tqdm_out = utils.TqdmToLogger(core_logger, level=logging.INFO)
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def use_gpu(gpu_number=0, use_torch=True):
    """
    Check whether a GPU (CUDA or MPS) is available for use.

    Args:
        gpu_number (int): Index of the GPU device to probe. Default is 0.
        use_torch (bool): Must be True; cellpose only supports PyTorch. Default is True.

    Returns:
        bool: True if a working GPU backend was found, False otherwise.

    Raises:
        ValueError: If use_torch is False, since cellpose only runs with PyTorch now.
    """
    # Guard clause: the non-torch path was removed in current cellpose.
    if not use_torch:
        raise ValueError("cellpose only runs with PyTorch now")
    return _use_gpu_torch(gpu_number)
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def _use_gpu_torch(gpu_number=0):
    """
    Checks if CUDA or MPS is available and working with PyTorch.

    Probes each backend by allocating a tiny tensor on it; any failure
    (missing driver, unsupported build, bad device index) falls through
    to the next backend.

    Args:
        gpu_number (int): The GPU device number to use (default is 0).

    Returns:
        bool: True if CUDA or MPS is available and working, False otherwise.
    """
    try:
        device = torch.device("cuda:" + str(gpu_number))
        _ = torch.zeros((1, 1)).to(device)  # tiny allocation proves the backend works
        core_logger.info("** TORCH CUDA version installed and working. **")
        return True
    except Exception:
        # FIX: was a bare `except:` which also swallowed KeyboardInterrupt/SystemExit
        pass
    try:
        device = torch.device('mps:' + str(gpu_number))
        _ = torch.zeros((1, 1)).to(device)
        core_logger.info('** TORCH MPS version installed and working. **')
        return True
    except Exception:
        # FIX: narrowed bare `except:`; also corrected the double-negative log message
        core_logger.info('Neither TORCH CUDA nor MPS version installed/working.')
        return False
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def assign_device(use_torch=True, gpu=False, device=0):
    """
    Assigns the device (CPU or GPU or mps) to be used for computation.

    Args:
        use_torch (bool, optional): Whether to use torch for GPU detection. Defaults to True.
        gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
        device (int or str, optional): The device index or name to be used. Defaults to 0.

    Returns:
        torch.device, bool (True if GPU is used, False otherwise)
    """

    if isinstance(device, str):
        # keep the string only for a usable "mps" request; otherwise it is a CUDA index
        if device != "mps" or not (gpu and torch.backends.mps.is_available()):
            device = int(device)
    # FIX: `cpu` was only assigned on some paths, so the `if cpu:` check below
    # could raise UnboundLocalError (e.g. gpu requested but neither CUDA nor
    # MPS available and no exception raised). Default to CPU up front.
    cpu = True
    if gpu and use_gpu(use_torch=True):
        try:
            if torch.cuda.is_available():
                device = torch.device(f'cuda:{device}')
                core_logger.info(">>>> using GPU (CUDA)")
                gpu = True
                cpu = False
        except Exception:
            gpu = False
            cpu = True
        if cpu:
            # FIX: only fall back to MPS if CUDA was not assigned; previously the
            # MPS probe always ran and could clobber a valid CUDA device.
            try:
                if torch.backends.mps.is_available():
                    device = torch.device('mps')
                    core_logger.info(">>>> using GPU (MPS)")
                    gpu = True
                    cpu = False
            except Exception:
                gpu = False
                cpu = True
    else:
        device = torch.device('cpu')
        core_logger.info('>>>> using CPU')
        gpu = False
        cpu = True

    if cpu:
        device = torch.device("cpu")
        core_logger.info(">>>> using CPU")
        gpu = False
    return device, gpu
|
| 110 |
+
|
| 111 |
+
|
| 112 |
+
def _to_device(x, device, dtype=torch.float32):
|
| 113 |
+
"""
|
| 114 |
+
Converts the input tensor or numpy array to the specified device.
|
| 115 |
+
|
| 116 |
+
Args:
|
| 117 |
+
x (torch.Tensor or numpy.ndarray): The input tensor or numpy array.
|
| 118 |
+
device (torch.device): The target device.
|
| 119 |
+
|
| 120 |
+
Returns:
|
| 121 |
+
torch.Tensor: The converted tensor on the specified device.
|
| 122 |
+
"""
|
| 123 |
+
if not isinstance(x, torch.Tensor):
|
| 124 |
+
X = torch.from_numpy(x).to(device, dtype=dtype)
|
| 125 |
+
return X
|
| 126 |
+
else:
|
| 127 |
+
return x
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
def _from_device(X):
|
| 131 |
+
"""
|
| 132 |
+
Converts a PyTorch tensor from the device to a NumPy array on the CPU.
|
| 133 |
+
|
| 134 |
+
Args:
|
| 135 |
+
X (torch.Tensor): The input PyTorch tensor.
|
| 136 |
+
|
| 137 |
+
Returns:
|
| 138 |
+
numpy.ndarray: The converted NumPy array.
|
| 139 |
+
"""
|
| 140 |
+
# The cast is so numpy conversion always works
|
| 141 |
+
x = X.detach().cpu().to(torch.float32).numpy()
|
| 142 |
+
return x
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
def _forward(net, x, feat=None):
    """Run the network on input images and return numpy outputs.

    Converts inputs to torch tensors on the network's device, evaluates the
    model under no_grad, and converts the results back to numpy.

    Args:
        net (torch.nn.Module): The network model; must expose .device and .dtype.
        x (numpy.ndarray): The input images.
        feat (numpy.ndarray, optional): Extra feature input forwarded to the net.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: The output predictions (flows and
        cellprob) and style features.
    """
    inputs = _to_device(x, device=net.device, dtype=net.dtype)
    if feat is not None:
        feat = _to_device(feat, device=net.device, dtype=net.dtype)
    net.eval()
    with torch.no_grad():
        # the net may return extra outputs; only the first two are used
        y, style = net(inputs, feat=feat)[:2]
    del inputs  # free device memory before converting outputs
    return _from_device(y), _from_device(style)
|
| 165 |
+
|
| 166 |
+
|
| 167 |
+
def run_net(net, imgi, feat=None, batch_size=8, augment=False, tile_overlap=0.1, bsize=224,
            rsz=None):
    """
    Run network on stack of images.

    (faster if augment is False)

    Args:
        net (class): cellpose network (model.net)
        imgi (np.ndarray): The input image or stack of images of size [Lz x Ly x Lx x nchan].
        feat (np.ndarray, optional): extra feature stack, tiled alongside imgi and passed
            to the network. Assumed same spatial shape as imgi — TODO confirm. Defaults to None.
        batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
        rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
        augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
        tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
        bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
        y[...,0] is Y flow; y[...,1] is X flow; y[...,2] is cell probability.
        style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
    """
    # run network
    Lz, Ly0, Lx0, nchan = imgi.shape
    if rsz is not None:
        # scalar rsz applies the same factor to both Y and X
        if not isinstance(rsz, list) and not isinstance(rsz, np.ndarray):
            rsz = [rsz, rsz]
        Lyr, Lxr = int(Ly0 * rsz[0]), int(Lx0 * rsz[1])
    else:
        Lyr, Lxr = Ly0, Lx0  # 512, 512

    # pad the (possibly resized) image so the tiling covers it exactly
    ly, lx = bsize, bsize  # 256, 256
    ypad1, ypad2, xpad1, xpad2 = transforms.get_pad_yx(Lyr, Lxr, min_size=(bsize, bsize))  # 8
    Ly, Lx = Lyr + ypad1 + ypad2, Lxr + xpad1 + xpad2  # 528, 528
    pads = np.array([[0, 0], [ypad1, ypad2], [xpad1, xpad2]])

    if augment:
        # augmentation uses a denser 2x grid so flipped overlaps can be averaged
        ny = max(2, int(np.ceil(2. * Ly / bsize)))
        nx = max(2, int(np.ceil(2. * Lx / bsize)))
    else:
        ny = 1 if Ly <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Ly / bsize))  # 3
        nx = 1 if Lx <= bsize else int(np.ceil((1. + 2 * tile_overlap) * Lx / bsize))  # 3


    # run multiple slices at the same time
    ntiles = ny * nx
    nimgs = max(1, batch_size // ntiles)  # number of imgs to run in the same batch, 1
    niter = int(np.ceil(Lz / nimgs))  # 1
    # progress bar only for long runs / multi-plane stacks
    ziterator = (trange(niter, file=tqdm_out, mininterval=30)
                 if niter > 10 or Lz > 1 else range(niter))
    for k in ziterator:
        inds = np.arange(k * nimgs, min(Lz, (k + 1) * nimgs))
        IMGa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")  # 9, 3, 256, 256
        if feat is not None:
            FEATa = np.zeros((ntiles * len(inds), nchan, ly, lx), "float32")  # 9, 256
        else:
            FEATa = None
        for i, b in enumerate(inds):
            # pad image for net so Ly and Lx are divisible by 4
            imgb = transforms.resize_image(imgi[b], rsz=rsz) if rsz is not None else imgi[b].copy()
            imgb = np.pad(imgb.transpose(2,0,1), pads, mode="constant")  # 3, 528, 528

            IMG, ysub, xsub, Lyt, Lxt = transforms.make_tiles(
                imgb, bsize=bsize, augment=augment,
                tile_overlap=tile_overlap)  # IMG: 3, 3, 3, 256, 256
            IMGa[i * ntiles : (i+1) * ntiles] = np.reshape(IMG,
                                                           (ny * nx, nchan, ly, lx))
            if feat is not None:
                # feature stack is resized/padded/tiled exactly like the image
                featb = transforms.resize_image(feat[b], rsz=rsz) if rsz is not None else feat[b].copy()
                featb = np.pad(featb.transpose(2,0,1), pads, mode="constant")
                FEAT, ysub, xsub, Lyt, Lxt = transforms.make_tiles(
                    featb, bsize=bsize, augment=augment,
                    tile_overlap=tile_overlap)
                FEATa[i * ntiles : (i+1) * ntiles] = np.reshape(FEAT,
                                                                (ny * nx, nchan, ly, lx))

        # run network
        for j in range(0, IMGa.shape[0], batch_size):
            bslc = slice(j, min(j + batch_size, IMGa.shape[0]))
            ya0, stylea0 = _forward(net, IMGa[bslc], feat=FEATa[bslc] if FEATa is not None else None)
            if j == 0:
                # output buffers allocated lazily: nout only known after first forward
                nout = ya0.shape[1]
                ya = np.zeros((IMGa.shape[0], nout, ly, lx), "float32")
                stylea = np.zeros((IMGa.shape[0], 256), "float32")
            ya[bslc] = ya0
            stylea[bslc] = stylea0

        # average tiles
        for i, b in enumerate(inds):
            if i==0 and k==0:
                yf = np.zeros((Lz, nout, Ly, Lx), "float32")
                styles = np.zeros((Lz, 256), "float32")
            y = ya[i * ntiles : (i + 1) * ntiles]
            if augment:
                # undo the flips applied to augmented tiles before averaging
                y = np.reshape(y, (ny, nx, 3, ly, lx))
                y = transforms.unaugment_tiles(y)
                y = np.reshape(y, (-1, 3, ly, lx))
            yfi = transforms.average_tiles(y, ysub, xsub, Lyt, Lxt)
            yf[b] = yfi[:, :imgb.shape[-2], :imgb.shape[-1]]
            # L2-normalize the summed per-tile style vector
            stylei = stylea[i * ntiles:(i + 1) * ntiles].sum(axis=0)
            stylei /= (stylei**2).sum()**0.5
            styles[b] = stylei
    # slices from padding
    yf = yf[:, :, ypad1 : Ly-ypad2, xpad1 : Lx-xpad2]
    yf = yf.transpose(0,2,3,1)
    return yf, np.array(styles)
|
| 272 |
+
|
| 273 |
+
|
| 274 |
+
def run_3D(net, imgs, batch_size=8, augment=False,
           tile_overlap=0.1, bsize=224, net_ortho=None,
           progress=None):
    """
    Run network on image z-stack.

    (faster if augment is False)

    Runs the 2D network over the three orthogonal plane orientations (YX, ZY,
    ZX) and accumulates their flow/cellprob predictions into one 4-channel
    volume.

    Args:
        imgs (np.ndarray): The input image stack of size [Lz x Ly x Lx x nchan].
        batch_size (int, optional): Number of tiles to run in a batch. Defaults to 8.
        rsz (float, optional): Resize coefficient(s) for image. Defaults to 1.0.
        anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
        augment (bool, optional): Tiles image with overlapping tiles and flips overlapped regions to augment. Defaults to False.
        tile_overlap (float, optional): Fraction of overlap of tiles when computing flows. Defaults to 0.1.
        bsize (int, optional): Size of tiles to use in pixels [bsize x bsize]. Defaults to 224.
        net_ortho (class, optional): cellpose network for orthogonal ZY and ZX planes.
            NOTE(review): not referenced in this body — all three orientations use `net`. Defaults to None.
        progress (QProgressBar, optional): pyqt progress bar. Defaults to None.

    Returns:
        Tuple[numpy.ndarray, numpy.ndarray]: outputs of network y and style. If tiled `y` is averaged in tile overlaps. Size of [Ly x Lx x 3] or [Lz x Ly x Lx x 3].
        y[...,0] is Z flow; y[...,1] is Y flow; y[...,2] is X flow; y[...,3] is cell probability.
        style is a 1D array of size 256 summarizing the style of the image, if tiled `style` is averaged over tiles.
    """
    sstr = ["YX", "ZY", "ZX"]
    # pm[p]: axis permutation turning the stack into planes of orientation p;
    # ipm[p]: permutation mapping that orientation's output back to ZYX order.
    pm = [(0, 1, 2, 3), (1, 0, 2, 3), (2, 0, 1, 3)]
    ipm = [(0, 1, 2), (1, 0, 2), (1, 2, 0)]
    # cp[p]: which output flow channels (Z=0,Y=1,X=2) orientation p contributes to;
    # cpy[p]: the corresponding 2D flow channels produced by run_net.
    cp = [(1, 2), (0, 2), (0, 1)]
    cpy = [(0, 1), (0, 1), (0, 1)]
    shape = imgs.shape[:-1]
    # accumulator: 3 flow channels + cellprob, summed over the 3 orientations
    yf = np.zeros((*shape, 4), "float32")
    for p in range(3):
        xsl = imgs.transpose(pm[p])
        # per image
        core_logger.info("running %s: %d planes of size (%d, %d)" %
                         (sstr[p], shape[pm[p][0]], shape[pm[p][1]], shape[pm[p][2]]))
        y, style = run_net(net,
                           xsl, batch_size=batch_size, augment=augment,
                           bsize=bsize, tile_overlap=tile_overlap,
                           rsz=None)
        # cellprob (last channel) accumulates from every orientation
        yf[..., -1] += y[..., -1].transpose(ipm[p])
        for j in range(2):
            yf[..., cp[p][j]] += y[..., cpy[p][j]].transpose(ipm[p])
        y = None; del y  # release per-orientation output before the next pass

        if progress is not None:
            progress.setValue(25 + 15 * p)

    # NOTE(review): `style` returned is from the last (ZX) orientation only
    return yf, style
|
models/seg_post_model/cellpose/denoise.py
ADDED
|
@@ -0,0 +1,1474 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import os, time, datetime
|
| 5 |
+
import numpy as np
|
| 6 |
+
from scipy.stats import mode
|
| 7 |
+
import cv2
|
| 8 |
+
import torch
|
| 9 |
+
from torch import nn
|
| 10 |
+
from torch.nn.functional import conv2d, interpolate
|
| 11 |
+
from tqdm import trange
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
|
| 14 |
+
import logging
|
| 15 |
+
|
| 16 |
+
denoise_logger = logging.getLogger(__name__)
|
| 17 |
+
|
| 18 |
+
from cellpose import transforms, utils, io
|
| 19 |
+
from cellpose.core import run_net
|
| 20 |
+
from cellpose.models import CellposeModel, model_path, normalize_default, assign_device
|
| 21 |
+
|
| 22 |
+
# Catalog of built-in restoration model names, in their canonical order.
# cyto3 models come only in the base variant; cyto2/nuclei additionally have
# per/seg/rec loss variants and an anisotropic model.
MODEL_NAMES = []
for cell_type in ["cyto3", "cyto2", "nuclei"]:
    has_variants = cell_type != "cyto3"
    for noise_type in ["denoise", "deblur", "upsample", "oneclick"]:
        MODEL_NAMES.append(f"{noise_type}_{cell_type}")
        if has_variants:
            MODEL_NAMES.extend(f"{noise_type}_{loss_type}_{cell_type}"
                               for loss_type in ["per", "seg", "rec"])
    if has_variants:
        MODEL_NAMES.append(f"aniso_{cell_type}")

# shared loss criteria used by the loss functions below
criterion = nn.MSELoss(reduction="mean")
criterion2 = nn.BCEWithLogitsLoss(reduction="mean")
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def deterministic(seed=0):
    """Seed all RNGs used by this module so runs are reproducible.

    Seeds torch (CPU and CUDA), numpy, and Python's ``random`` module, and
    forces cuDNN into deterministic mode (disables autotuning).

    Args:
        seed (int, optional): Seed value for every RNG. Defaults to 0.
    """
    import random
    # fix: the original called torch.manual_seed(seed) twice; once is enough
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)  # if you are using multi-GPU.
    np.random.seed(seed)  # Numpy module.
    random.seed(seed)  # Python random module.
    # disable cuDNN autotuning and select deterministic kernels
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def loss_fn_rec(lbl, y):
    """Reconstruction loss between the clean target and the restored image.

    Args:
        lbl (torch.Tensor): Ground-truth (clean) image tensor.
        y (torch.Tensor): Network prediction tensor.

    Returns:
        torch.Tensor: Scalar mean-squared error scaled by 80.
    """
    return 80. * criterion(y, lbl)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def loss_fn_seg(lbl, y):
    """Segmentation loss on predicted flows and cell probability.

    Args:
        lbl (torch.Tensor): Labels (batch, 3, Ly, Lx) — cellprob in channel 0,
            flow components in channels 1-2.
        y (torch.Tensor): Network output (batch, 3, Ly, Lx) — flows in
            channels 0-1, cellprob logits in channel 2.

    Returns:
        torch.Tensor: Scalar loss (halved MSE on the scaled flows plus BCE on
        the cell probability).
    """
    flows_gt = 5. * lbl[:, 1:]
    cellprob_gt = (lbl[:, 0] > .5).float()
    flow_loss = criterion(y[:, :2], flows_gt) / 2.
    prob_loss = criterion2(y[:, 2], cellprob_gt)
    return flow_loss + prob_loss
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def get_sigma(Tdown):
    """Compute per-image channel correlation matrices for the perceptual loss.

    Each tensor is centered and scaled per-channel over its spatial axes, then
    the channel-by-channel correlation matrix is formed.

    Args:
        Tdown (list): Tensors of shape (batch, nchan, Ly, Lx) output by each
            downsampling block of the network.

    Returns:
        list: One (batch, nchan, nchan) correlation tensor per input tensor.
    """
    Sigma = []
    for x in Tdown:
        z = x - x.mean((-2, -1), keepdim=True)
        z = z / z.std((-2, -1), keepdim=True)
        npix = z.shape[-2] * z.shape[-1]
        Sigma.append(torch.einsum("bnxy, bmxy -> bnm", z, z) / npix)
    return Sigma
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def imstats(X, net1):
    """Run an image through the net and return detached correlation matrices.

    Args:
        X (torch.Tensor): Input image tensor.
        net1: Cellpose net; its third return value is the list of
            downsampling-block activations.

    Returns:
        list: Detached correlation-matrix tensors (no gradients flow through
        the reference statistics).
    """
    _, _, Tdown = net1(X)
    return [s.detach() for s in get_sigma(Tdown)]
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def loss_fn_per(img, net1, yl):
    """Perceptual loss for image restoration.

    Compares channel-correlation statistics of the clean image (computed by
    running it through the perceptual net) with statistics of the restored
    image's activations, normalizing each level by the spread of the
    reference statistics.

    Args:
        img (torch.Tensor): Clean image tensor (reference).
        net1 (torch.nn.Module): Perceptual net (Cellpose segmentation net).
        yl (list): Downsampling-block activations of the restored image.

    Returns:
        torch.Tensor: Mean perceptual loss across the batch.
    """
    Sigma_ref = imstats(img, net1)
    Sigma_test = get_sigma(yl)
    losses = torch.zeros(len(Sigma_ref[0]), device=img.device)
    for ref, test in zip(Sigma_ref, Sigma_test):
        sd = ref.std((1, 2)) + 1e-6  # epsilon guards against zero spread
        losses = losses + ((test - ref)**2).mean((1, 2)) / sd**2
    return losses.mean()
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def test_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
    """Evaluate the restoration loss without gradients (eval mode).

    Args:
        net0 (torch.nn.Module): Image restoration network.
        X (torch.Tensor): Input (noisy/blurry/downsampled) image tensor.
        net1 (torch.nn.Module, optional): Segmentation net used for the
            segmentation and perceptual terms. Defaults to None.
        img (torch.Tensor, optional): Clean image for the perceptual and
            reconstruction terms. Defaults to None.
        lbl (torch.Tensor, optional): Ground-truth flows/cellprob for the
            segmentation term. Defaults to None.
        lam (list, optional): Weights [perceptual, segmentation,
            reconstruction]. Defaults to [1., 1.5, 0.].

    Returns:
        tuple: (total loss, perceptual loss), each a 1-element tensor.
    """
    lam_per, lam_seg, lam_rec = lam[0], lam[1], lam[2]
    net0.eval()
    if net1 is not None:
        net1.eval()
    loss = torch.zeros(1, device=X.device)
    loss_per = torch.zeros(1, device=X.device)

    with torch.no_grad():
        restored = net0(X)[0]
        if lam_rec > 0.:
            loss += lam_rec * loss_fn_rec(img, restored)
        if lam_seg > 0. or lam_per > 0.:
            # one forward pass through the segmentation net serves both terms
            y, _, ydown = net1(restored)
            if lam_seg > 0.:
                loss += lam_seg * loss_fn_seg(lbl, y)
            if lam_per > 0.:
                loss_per = loss_fn_per(img, net1, ydown)
                loss += lam_per * loss_per
    return loss, loss_per
|
| 154 |
+
|
| 155 |
+
|
| 156 |
+
def train_loss(net0, X, net1=None, img=None, lbl=None, lam=[1., 1.5, 0.]):
    """Compute the restoration loss with gradients enabled (train mode).

    Same loss terms as :func:`test_loss`, but ``net0`` is put in train mode
    and no ``no_grad`` context is used, so the result can be backpropagated.
    The segmentation net stays in eval mode (it is frozen).

    Args:
        net0 (torch.nn.Module): Image restoration network.
        X (torch.Tensor): Input (noisy/blurry/downsampled) image tensor.
        net1 (torch.nn.Module, optional): Segmentation net used for the
            segmentation and perceptual terms. Defaults to None.
        img (torch.Tensor, optional): Clean image for the perceptual and
            reconstruction terms. Defaults to None.
        lbl (torch.Tensor, optional): Ground-truth flows/cellprob for the
            segmentation term. Defaults to None.
        lam (list, optional): Weights [perceptual, segmentation,
            reconstruction]. Defaults to [1., 1.5, 0.].

    Returns:
        tuple: (total loss, perceptual loss), each a 1-element tensor.
    """
    lam_per, lam_seg, lam_rec = lam[0], lam[1], lam[2]
    net0.train()
    if net1 is not None:
        net1.eval()
    loss = torch.zeros(1, device=X.device)
    loss_per = torch.zeros(1, device=X.device)

    restored = net0(X)[0]
    if lam_rec > 0.:
        loss += lam_rec * loss_fn_rec(img, restored)
    if lam_seg > 0. or lam_per > 0.:
        # one forward pass through the segmentation net serves both terms
        y, _, ydown = net1(restored)
        if lam_seg > 0.:
            loss += lam_seg * loss_fn_seg(lbl, y)
        if lam_per > 0.:
            loss_per = loss_fn_per(img, net1, ydown)
            loss += lam_per * loss_per
    return loss, loss_per
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
def img_norm(imgi):
    """Percentile-normalize each image channel to roughly [0, 1].

    Subtracts the 1st percentile and divides by the 99th-minus-1st percentile
    range, per image and per channel. Channels whose range is below 1e-3 are
    left unchanged. Note: operates in place on the input's storage.

    Args:
        imgi (torch.Tensor): Image tensor of shape (nimg, nchan, Ly, Lx).

    Returns:
        torch.Tensor: Normalized tensor with the original shape.
    """
    orig_shape = imgi.shape
    flat = imgi.reshape(orig_shape[0], orig_shape[1], -1)
    qs = torch.tensor([0.01, 0.99], device=flat.device)
    perc = torch.quantile(flat, qs, dim=-1, keepdim=True)  # (2, nimg, nchan, 1)
    lo, hi = perc[0], perc[1]
    for c in range(flat.shape[1]):
        # skip near-constant channels to avoid dividing by ~0
        valid = (hi[:, c, 0] - lo[:, c, 0]) > 1e-3
        flat[valid, c] -= lo[valid, c]
        flat[valid, c] /= (hi[valid, c] - lo[valid, c])
    return flat.reshape(orig_shape)
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
def add_noise(lbl, alpha=4, beta=0.7, poisson=0.7, blur=0.7, gblur=1.0, downsample=0.7,
              ds_max=7, diams=None, pscale=None, iso=True, sigma0=None, sigma1=None,
              ds=None, uniform_blur=False, partial_blur=False):
    """Adds noise to the input image.

    Applies, per image and independently at random: downsampling, (possibly
    anisotropic) Gaussian blur, and Poisson noise, then percentile-renormalizes.

    Args:
        lbl (torch.Tensor): The input image tensor of shape (nimg, nchan, Ly, Lx).
        alpha (float, optional): The shape parameter of the gamma distribution used for generating poisson noise. Defaults to 4.
        beta (float, optional): The rate parameter of the gamma distribution used for generating poisson noise. Defaults to 0.7.
        poisson (float, optional): The probability of adding poisson noise to the image. Defaults to 0.7.
        blur (float, optional): The probability of adding gaussian blur to the image. Defaults to 0.7.
        gblur (float, optional): The scale factor for the gaussian blur. Defaults to 1.0.
        downsample (float, optional): The probability of downsampling the image. Defaults to 0.7.
        ds_max (int, optional): The maximum downsampling factor. Defaults to 7.
        diams (torch.Tensor, optional): The diameter of the objects in the image. Defaults to None.
        pscale (torch.Tensor, optional): The scale factor for the poisson noise, instead of sampling. Defaults to None.
        iso (bool, optional): Whether to use isotropic gaussian blur. Defaults to True.
        sigma0 (torch.Tensor, optional): The standard deviation of the gaussian filter for the Y axis, instead of sampling. Defaults to None.
        sigma1 (torch.Tensor, optional): The standard deviation of the gaussian filter for the X axis, instead of sampling. Defaults to None.
        ds (torch.Tensor, optional): The downsampling factor for each image, instead of sampling. Defaults to None.
        uniform_blur (bool, optional): Sample blur levels uniformly rather than exponentially. Defaults to False.
        partial_blur (bool, optional): Blend the blurred image into only part of the frame. Defaults to False.

    Returns:
        torch.Tensor: The noisy image tensor of the same shape as the input image.
    """
    device = lbl.device
    imgi = torch.zeros_like(lbl)
    Ly, Lx = lbl.shape[-2:]

    # default to a nominal diameter of 30 px per image
    diams = diams if diams is not None else 30. * torch.ones(len(lbl), device=device)
    #ds0 = 1 if ds is None else ds.item()
    # broadcast a scalar user-supplied downsampling factor to all images
    ds = ds * torch.ones(
        (len(lbl),), device=device, dtype=torch.long) if ds is not None else ds

    # downsample
    # ii will hold the indices of images that actually get downsampled
    ii = []
    idownsample = np.random.rand(len(lbl)) < downsample
    if (ds is None and idownsample.sum() > 0.) or not iso:
        # sample an integer factor in [2, ds_max] for the selected images
        ds = torch.ones(len(lbl), dtype=torch.long, device=device)
        ds[idownsample] = torch.randint(2, ds_max + 1, size=(idownsample.sum(),),
                                        device=device)
        ii = torch.nonzero(ds > 1).flatten()
    elif ds is not None and (ds > 1).sum():
        ii = torch.nonzero(ds > 1).flatten()

    # add gaussian blur
    iblur = torch.rand(len(lbl), device=device) < blur
    # downsampled images are always blurred first (anti-aliasing)
    iblur[ii] = True
    if iblur.sum() > 0:
        if sigma0 is None:
            if uniform_blur and iso:
                # uniform blur level; downsampled images get sigma tied to ds
                xr = torch.rand(len(lbl), device=device)
                if len(ii) > 0:
                    xr[ii] = ds[ii].float() / 2. / gblur
                sigma0 = diams[iblur] / 30. * gblur * (1 / gblur + (1 - 1 / gblur) * xr[iblur])
                sigma1 = sigma0.clone()
            elif not iso:
                # anisotropic: strong blur along Y, 10x weaker along X
                xr = torch.rand(len(lbl), device=device)
                if len(ii) > 0:
                    xr[ii] = (ds[ii].float()) / gblur
                    xr[ii] = xr[ii] + torch.rand(len(ii), device=device) * 0.7 - 0.35
                    xr[ii] = torch.clip(xr[ii], 0.05, 1.5)
                sigma0 = diams[iblur] / 30. * gblur * xr[iblur]
                sigma1 = sigma0.clone() / 10.
            else:
                # default: exponentially distributed blur level, clipped
                xrand = np.random.exponential(1, size=iblur.sum())
                xrand = np.clip(xrand * 0.5, 0.1, 1.0)
                xrand *= gblur
                sigma0 = diams[iblur] / 30. * 5. * torch.from_numpy(xrand).float().to(
                    device)
                sigma1 = sigma0.clone()
        else:
            # user-specified sigmas, broadcast to all blurred images
            sigma0 = sigma0 * torch.ones((iblur.sum(),), device=device)
            sigma1 = sigma1 * torch.ones((iblur.sum(),), device=device)

        # create gaussian filter
        # xr is reused here as the filter half-width (int), at least 8
        xr = max(8, sigma0.max().long() * 2)
        gfilt0 = torch.exp(-torch.arange(-xr + 1, xr, device=device)**2 /
                           (2 * sigma0.unsqueeze(-1)**2))
        gfilt0 /= gfilt0.sum(axis=-1, keepdims=True)
        gfilt1 = torch.zeros_like(gfilt0)
        # reuse the Y filter where sigmas match, else build a separate X filter
        gfilt1[sigma1 == sigma0] = gfilt0[sigma1 == sigma0]
        gfilt1[sigma1 != sigma0] = torch.exp(
            -torch.arange(-xr + 1, xr, device=device)**2 /
            (2 * sigma1[sigma1 != sigma0].unsqueeze(-1)**2))
        # sigma1 == 0 means no blur along X: use a delta filter
        # NOTE(review): the filter has 2*xr-1 taps so its center is index
        # xr-1; index xr looks off by one — confirm intended center.
        gfilt1[sigma1 == 0] = 0.
        gfilt1[sigma1 == 0, xr] = 1.
        gfilt1 /= gfilt1.sum(axis=-1, keepdims=True)
        # outer product of the 1D filters gives the 2D separable kernel
        gfilt = torch.einsum("ck,cl->ckl", gfilt0, gfilt1)
        gfilt /= gfilt.sum(axis=(1, 2), keepdims=True)

        # depthwise conv: one kernel per blurred image (images moved to the
        # channel dimension via transpose, groups = number of images)
        lbl_blur = conv2d(lbl[iblur].transpose(1, 0), gfilt.unsqueeze(1),
                          padding=gfilt.shape[-1] // 2,
                          groups=gfilt.shape[0]).transpose(1, 0)
        if partial_blur:
            #yc, xc = np.random.randint(100, Ly-100), np.random.randint(100, Lx-100)
            # blend blurred content into the left 85% of the frame
            imgi[iblur] = lbl[iblur].clone()
            Lxc = int(Lx * 0.85)
            ym, xm = torch.meshgrid(torch.zeros(Ly, dtype=torch.float32),
                                    torch.arange(0, Lxc, dtype=torch.float32),
                                    indexing="ij")
            # NOTE(review): "/ 2*(0.001**2)" divides by 2 then multiplies by
            # 1e-6 (precedence); likely intended "/ (2*(0.001**2))" — confirm.
            mask = torch.exp(-(ym**2 + xm**2) / 2*(0.001**2))
            mask -= mask.min()
            mask /= mask.max()
            lbl_blur_crop = lbl_blur[:, :, :, :Lxc]
            imgi[iblur, :, :, :Lxc] = (lbl_blur_crop * mask +
                                       (1-mask) * imgi[iblur, :, :, :Lxc])
        else:
            imgi[iblur] = lbl_blur

    # non-blurred images pass through unchanged
    imgi[~iblur] = lbl[~iblur]

    # apply downsample
    for k in ii:
        # stride-subsample (both axes if iso, Y only otherwise), then
        # bilinearly upsample back to the original size
        i0 = imgi[k:k + 1, :, ::ds[k], ::ds[k]] if iso else imgi[k:k + 1, :, ::ds[k]]
        imgi[k] = interpolate(i0, size=lbl[k].shape[-2:], mode="bilinear")

    # add poisson noise
    ipoisson = np.random.rand(len(lbl)) < poisson
    if ipoisson.sum() > 0:
        if pscale is None:
            pscale = torch.zeros(len(lbl))
            # photon-count scale drawn from a gamma distribution, min 1
            m = torch.distributions.gamma.Gamma(alpha, beta)
            pscale = torch.clamp(m.rsample(sample_shape=(ipoisson.sum(),)), 1.)
            #pscale = torch.clamp(20 * (torch.rand(size=(len(lbl),), device=lbl.device)), 1.5)
            pscale = pscale.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).to(device)
        else:
            pscale = pscale * torch.ones((ipoisson.sum(), 1, 1, 1), device=device)
        imgi[ipoisson] = torch.poisson(pscale * imgi[ipoisson])
        # NOTE(review): self-assignment below is a no-op; kept for parity
        imgi[~ipoisson] = imgi[~ipoisson]

    # renormalize
    imgi = img_norm(imgi)

    return imgi
|
| 345 |
+
|
| 346 |
+
|
| 347 |
+
def random_rotate_and_resize_noise(data, labels=None, diams=None, poisson=0.7, blur=0.7,
                                   downsample=0.0, beta=0.7, gblur=1.0, diam_mean=30,
                                   ds_max=7, uniform_blur=False, iso=True, rotate=True,
                                   device=torch.device("cuda"), xy=(224, 224),
                                   nchan_noise=1, keep_raw=True):
    """
    Applies random rotation, resizing, and noise to the input data.

    Args:
        data (numpy.ndarray): The input data.
        labels (numpy.ndarray, optional): The flow and cellprob labels associated with the data. Defaults to None.
        diams (float, optional): The diameter of the objects. Defaults to None.
        poisson (float, optional): The Poisson noise probability. Defaults to 0.7.
        blur (float, optional): The blur probability. Defaults to 0.7.
        downsample (float, optional): The downsample probability. Defaults to 0.0.
        beta (float, optional): The beta value for the poisson noise distribution. Defaults to 0.7.
        gblur (float, optional): The Gaussian blur level. Defaults to 1.0.
        diam_mean (float, optional): The mean diameter. Defaults to 30.
        ds_max (int, optional): The maximum downsample value. Defaults to 7.
        uniform_blur (bool, optional): Whether to sample blur levels uniformly. Defaults to False.
        iso (bool, optional): Whether to apply isotropic augmentation. Defaults to True.
        rotate (bool, optional): Whether to apply rotation augmentation. Defaults to True.
        device (torch.device, optional): The device to use; if None, picks
            CUDA, then MPS, then CPU. Defaults to torch.device("cuda").
        xy (tuple, optional): The size of the output image. Defaults to (224, 224).
        nchan_noise (int, optional): The number of channels to add noise to. Defaults to 1.
        keep_raw (bool, optional): Whether to keep the raw image. Defaults to True.

    Returns:
        torch.Tensor: The augmented image and augmented noisy/blurry/downsampled version of image.
        torch.Tensor: The augmented labels.
        float: The scale factor applied to the image.
    """
    # fix: use "is None" (identity test) and fall back to CPU instead of
    # leaving device as None when neither CUDA nor MPS is available
    if device is None:
        device = (torch.device("cuda") if torch.cuda.is_available() else
                  torch.device("mps") if torch.backends.mps.is_available() else
                  torch.device("cpu"))

    diams = 30 if diams is None else diams
    # random target diameter in [diam_mean/2, diam_mean*2]
    random_diam = diam_mean * (2**(2 * np.random.rand(len(data)) - 1))
    random_rsc = diams / random_diam
    # images are first resized/cropped into a 340x340 working canvas
    xy0 = (340, 340)
    nchan = data[0].shape[0]
    # if keep_raw, the clean copy is stored in the second half of the channels
    data_new = np.zeros((len(data), (1 + keep_raw) * nchan, xy0[0], xy0[1]), "float32")
    labels_new = np.zeros((len(data), 3, xy0[0], xy0[1]), "float32")
    for i in range(len(data)):
        sc = random_rsc[i]
        img = data[i]
        lbl = labels[i] if labels is not None else None
        # create affine transform to resize (scale by 1/sc with a random
        # translation so different crops are seen)
        Ly, Lx = img.shape[-2:]
        dxy = np.maximum(0, np.array([Lx / sc - xy0[1], Ly / sc - xy0[0]]))
        dxy = (np.random.rand(2,) - .5) * dxy
        cc = np.array([Lx / 2, Ly / 2])
        cc1 = cc - np.array([Lx - xy0[1], Ly - xy0[0]]) / 2 + dxy
        pts1 = np.float32([cc, cc + np.array([1, 0]), cc + np.array([0, 1])])
        pts2 = np.float32(
            [cc1, cc1 + np.array([1, 0]) / sc, cc1 + np.array([0, 1]) / sc])
        M = cv2.getAffineTransform(pts1, pts2)

        # apply to image
        for c in range(nchan):
            img_rsz = cv2.warpAffine(img[c], M, xy0, flags=cv2.INTER_LINEAR)
            data_new[i, c] = img_rsz
            if keep_raw:
                data_new[i, c + nchan] = img_rsz

        if lbl is not None:
            # apply to labels (nearest for cellprob, linear for flows)
            labels_new[i, 0] = cv2.warpAffine(lbl[0], M, xy0, flags=cv2.INTER_NEAREST)
            labels_new[i, 1] = cv2.warpAffine(lbl[1], M, xy0, flags=cv2.INTER_LINEAR)
            labels_new[i, 2] = cv2.warpAffine(lbl[2], M, xy0, flags=cv2.INTER_LINEAR)

    rsc = random_diam / diam_mean

    # add noise before augmentations
    img = torch.from_numpy(data_new).to(device)
    img = torch.clamp(img, 0.)
    # just add noise to cyto if nchan_noise=1
    img[:, :nchan_noise] = add_noise(
        img[:, :nchan_noise], poisson=poisson, blur=blur, ds_max=ds_max, iso=iso,
        downsample=downsample, beta=beta, gblur=gblur,
        diams=torch.from_numpy(random_diam).to(device).float())
    img = img.cpu().numpy()

    # augmentations (rotation disabled for anisotropic data)
    img, lbl, scale = transforms.random_rotate_and_resize(
        img,
        Y=labels_new,
        xy=xy,
        rotate=False if not iso else rotate,
        rescale=rsc,
        scale_range=0.5)
    img = torch.from_numpy(img).to(device)
    lbl = torch.from_numpy(lbl).to(device)

    return img, lbl, scale
|
| 446 |
+
|
| 447 |
+
|
| 448 |
+
def one_chan_cellpose(device, model_type="cyto2", pretrained_model=None):
    """
    Creates a Cellpose network with a single input channel.

    Pretrained two-channel weights are copied over; for the first
    downsampling block the two input-channel weights are collapsed into one.

    Args:
        device (str): The device to run the network on.
        model_type (str, optional): The type of Cellpose model to use. Defaults to "cyto2".
        pretrained_model (str, optional): The path to a pretrained model file. Defaults to None.

    Returns:
        torch.nn.Module: The Cellpose network with a single input channel.
    """
    # fix: resnet_torch was referenced but never imported at module level,
    # which raised NameError at runtime; import it locally here
    from cellpose import resnet_torch

    if pretrained_model is not None and not os.path.exists(pretrained_model):
        # a non-existent path is treated as a built-in model name
        model_type = pretrained_model
        pretrained_model = None
    nbase = [32, 64, 128, 256]
    nchan = 1
    net1 = resnet_torch.CPnet([nchan, *nbase], nout=3, sz=3).to(device)
    filename = model_path(model_type,
                          0) if pretrained_model is None else pretrained_model
    weights = torch.load(filename, weights_only=True)
    zp = 0  # input-channel slot that receives the copied weights
    print(filename)
    for name in net1.state_dict():
        if ("res_down_0.conv.conv_0" not in name and
                #"output" not in name and
                "res_down_0.proj" not in name and name != "diam_mean" and
                name != "diam_labels"):
            # layers whose shapes match: copy directly
            net1.state_dict()[name].copy_(weights[name])
        elif "res_down_0" in name:
            # first block: collapse the 2-channel weights down to 1 channel
            if len(weights[name].shape) > 0:
                new_weight = torch.zeros_like(net1.state_dict()[name])
                if weights[name].shape[0] == 2:
                    new_weight[:] = weights[name][0]
                elif len(weights[name].shape) > 1 and weights[name].shape[1] == 2:
                    new_weight[:, zp] = weights[name][:, 0]
                else:
                    new_weight = weights[name]
            else:
                new_weight = weights[name]
            net1.state_dict()[name].copy_(new_weight)
    return net1
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class CellposeDenoiseModel():
|
| 493 |
+
""" model to run Cellpose and Image restoration """
|
| 494 |
+
|
| 495 |
+
    def __init__(self, gpu=False, pretrained_model=False, model_type=None,
                 restore_type="denoise_cyto3", nchan=2,
                 chan2_restore=False, device=None):
        """Initialize the combined restoration + segmentation model.

        Args:
            gpu (bool, optional): Whether to use the GPU. Defaults to False.
            pretrained_model (bool or str, optional): Pretrained segmentation
                model path passed to CellposeModel. Defaults to False.
            model_type (str, optional): Cellpose segmentation model type.
                Defaults to None.
            restore_type (str, optional): Restoration model type. Defaults
                to "denoise_cyto3".
            nchan (int, optional): Number of segmentation input channels.
                Defaults to 2.
            chan2_restore (bool, optional): Whether to also restore the
                second (nuclear) channel. Defaults to False.
            device (torch.device, optional): Device to run on. Defaults to
                None.
        """
        # image restoration network (DenoiseModel is defined elsewhere in this module)
        self.dn = DenoiseModel(gpu=gpu, model_type=restore_type, chan2=chan2_restore,
                               device=device)
        # Cellpose segmentation network applied to the restored images
        self.cp = CellposeModel(gpu=gpu, model_type=model_type, nchan=nchan,
                                pretrained_model=pretrained_model, device=device)
|
| 503 |
+
|
| 504 |
+
def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
|
| 505 |
+
normalize=True, rescale=None, diameter=None, tile_overlap=0.1,
|
| 506 |
+
augment=False, resample=True, invert=False, flow_threshold=0.4,
|
| 507 |
+
cellprob_threshold=0.0, do_3D=False, anisotropy=None, stitch_threshold=0.0,
|
| 508 |
+
min_size=15, niter=None, interp=True, bsize=224, flow3D_smooth=0):
|
| 509 |
+
"""
|
| 510 |
+
Restore array or list of images using the image restoration model, and then segment.
|
| 511 |
+
|
| 512 |
+
Args:
|
| 513 |
+
x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
|
| 514 |
+
batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
|
| 515 |
+
(can make smaller or bigger depending on GPU memory usage). Defaults to 8.
|
| 516 |
+
channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
|
| 517 |
+
First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
|
| 518 |
+
Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
|
| 519 |
+
For instance, to segment grayscale images, input [0,0]. To segment images with cells
|
| 520 |
+
in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
|
| 521 |
+
image with cells in green and nuclei in blue, input [[0,0], [2,3]].
|
| 522 |
+
Defaults to None.
|
| 523 |
+
channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
|
| 524 |
+
if None, channels dimension is attempted to be automatically determined. Defaults to None.
|
| 525 |
+
z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
|
| 526 |
+
if None, z dimension is attempted to be automatically determined. Defaults to None.
|
| 527 |
+
normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
|
| 528 |
+
can also pass dictionary of parameters (all keys are optional, default values shown):
|
| 529 |
+
- "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
|
| 530 |
+
- "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
|
| 531 |
+
- "normalize"=True ; run normalization (if False, all following parameters ignored)
|
| 532 |
+
- "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
|
| 533 |
+
- "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
|
| 534 |
+
- "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
|
| 535 |
+
Defaults to True.
|
| 536 |
+
rescale (float, optional): resize factor for each image, if None, set to 1.0;
|
| 537 |
+
(only used if diameter is None). Defaults to None.
|
| 538 |
+
diameter (float, optional): diameter for each image,
|
| 539 |
+
if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
|
| 540 |
+
tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.
|
| 541 |
+
augment (bool, optional): augment tiles by flipping and averaging for segmentation. Defaults to False.
|
| 542 |
+
resample (bool, optional): run dynamics at original image size (will be slower but create more accurate boundaries). Defaults to True.
|
| 543 |
+
invert (bool, optional): invert image pixel intensity before running network. Defaults to False.
|
| 544 |
+
flow_threshold (float, optional): flow error threshold (all cells with errors below threshold are kept) (not used for 3D). Defaults to 0.4.
|
| 545 |
+
cellprob_threshold (float, optional): all pixels with value above threshold kept for masks, decrease to find more and larger masks. Defaults to 0.0.
|
| 546 |
+
do_3D (bool, optional): set to True to run 3D segmentation on 3D/4D image input. Defaults to False.
|
| 547 |
+
anisotropy (float, optional): for 3D segmentation, optional rescaling factor (e.g. set to 2.0 if Z is sampled half as dense as X or Y). Defaults to None.
|
| 548 |
+
stitch_threshold (float, optional): if stitch_threshold>0.0 and not do_3D, masks are stitched in 3D to return volume segmentation. Defaults to 0.0.
|
| 549 |
+
min_size (int, optional): all ROIs below this size, in pixels, will be discarded. Defaults to 15.
|
| 550 |
+
flow3D_smooth (int, optional): if do_3D and flow3D_smooth>0, smooth flows with gaussian filter of this stddev. Defaults to 0.
|
| 551 |
+
niter (int, optional): number of iterations for dynamics computation. if None, it is set proportional to the diameter. Defaults to None.
|
| 552 |
+
interp (bool, optional): interpolate during 2D dynamics (not available in 3D) . Defaults to True.
|
| 553 |
+
|
| 554 |
+
Returns:
|
| 555 |
+
A tuple containing (masks, flows, styles, imgs); masks: labelled image(s), where 0=no masks; 1,2,...=mask labels;
|
| 556 |
+
flows: list of lists: flows[k][0] = XY flow in HSV 0-255; flows[k][1] = XY(Z) flows at each pixel; flows[k][2] = cell probability (if > cellprob_threshold, pixel used for dynamics); flows[k][3] = final pixel locations after Euler integration;
|
| 557 |
+
styles: style vector summarizing each image of size 256;
|
| 558 |
+
imgs: Restored images.
|
| 559 |
+
"""
|
| 560 |
+
|
| 561 |
+
if isinstance(normalize, dict):
|
| 562 |
+
normalize_params = {**normalize_default, **normalize}
|
| 563 |
+
elif not isinstance(normalize, bool):
|
| 564 |
+
raise ValueError("normalize parameter must be a bool or a dict")
|
| 565 |
+
else:
|
| 566 |
+
normalize_params = normalize_default
|
| 567 |
+
normalize_params["normalize"] = normalize
|
| 568 |
+
normalize_params["invert"] = invert
|
| 569 |
+
|
| 570 |
+
img_restore = self.dn.eval(x, batch_size=batch_size, channels=channels,
|
| 571 |
+
channel_axis=channel_axis, z_axis=z_axis,
|
| 572 |
+
do_3D=do_3D,
|
| 573 |
+
normalize=normalize_params, rescale=rescale,
|
| 574 |
+
diameter=diameter,
|
| 575 |
+
tile_overlap=tile_overlap, bsize=bsize)
|
| 576 |
+
|
| 577 |
+
# turn off special normalization for segmentation
|
| 578 |
+
normalize_params = normalize_default
|
| 579 |
+
|
| 580 |
+
# change channels for segmentation
|
| 581 |
+
if channels is not None:
|
| 582 |
+
channels_new = [0, 0] if channels[0] == 0 else [1, 2]
|
| 583 |
+
else:
|
| 584 |
+
channels_new = None
|
| 585 |
+
# change diameter if self.ratio > 1 (upsampled to self.dn.diam_mean)
|
| 586 |
+
diameter = self.dn.diam_mean if self.dn.ratio > 1 else diameter
|
| 587 |
+
masks, flows, styles = self.cp.eval(
|
| 588 |
+
img_restore, batch_size=batch_size, channels=channels_new, channel_axis=-1,
|
| 589 |
+
z_axis=0 if not isinstance(img_restore, list) and img_restore.ndim > 3 and img_restore.shape[0] > 0 else None,
|
| 590 |
+
normalize=normalize_params, rescale=rescale, diameter=diameter,
|
| 591 |
+
tile_overlap=tile_overlap, augment=augment, resample=resample,
|
| 592 |
+
invert=invert, flow_threshold=flow_threshold,
|
| 593 |
+
cellprob_threshold=cellprob_threshold, do_3D=do_3D, anisotropy=anisotropy,
|
| 594 |
+
stitch_threshold=stitch_threshold, min_size=min_size, niter=niter,
|
| 595 |
+
interp=interp, bsize=bsize)
|
| 596 |
+
|
| 597 |
+
return masks, flows, styles, img_restore
|
| 598 |
+
|
| 599 |
+
|
| 600 |
+
class DenoiseModel():
    """
    DenoiseModel class for denoising images using Cellpose denoising model.

    Args:
        gpu (bool, optional): Whether to use GPU for computation. Defaults to False.
        pretrained_model (bool or str or Path, optional): Pretrained model to use for denoising.
            Can be a string or path. Defaults to False.
        nchan (int, optional): Number of channels in the input images, all Cellpose 3 models were trained with nchan=1. Defaults to 1.
        model_type (str, optional): Type of pretrained model to use ("denoise_cyto3", "deblur_cyto3", "upsample_cyto3", ...). Defaults to None.
        chan2 (bool, optional): Whether to use a separate model for the second channel. Defaults to False.
        diam_mean (float, optional): Mean diameter of the objects in the images. Defaults to 30.0.
        device (torch.device, optional): Device to use for computation. Defaults to None.

    Attributes:
        nchan (int): Number of channels in the input images.
        diam_mean (float): Mean diameter of the objects in the images.
        net (CPnet): Cellpose network for denoising.
        pretrained_model (bool or str or Path): Pretrained model path to use for denoising.
        net_chan2 (CPnet or None): Cellpose network for the second channel, if applicable.
        net_type (str): Type of the denoising network.

    Methods:
        eval(x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
            normalize=True, rescale=None, diameter=None, tile=True, tile_overlap=0.1)
            Denoise array or list of images using the denoising model.

        _eval(net, x, normalize=True, rescale=None, diameter=None, tile=True,
            tile_overlap=0.1)
            Run denoising model on a single channel.
    """

    def __init__(self, gpu=False, pretrained_model=False, nchan=1, model_type=None,
                 chan2=False, diam_mean=30., device=None):
        self.nchan = nchan
        if pretrained_model and (not isinstance(pretrained_model, str) and
                                 not isinstance(pretrained_model, Path)):
            raise ValueError("pretrained_model must be a string or path")

        self.diam_mean = diam_mean
        builtin = True
        # Resolve model_type / unknown paths to a built-in model name,
        # falling back to "denoise_cyto3" for unrecognized names.
        if model_type is not None or (pretrained_model and
                                      not os.path.exists(pretrained_model)):
            pretrained_model_string = model_type if model_type is not None else "denoise_cyto3"
            if ~np.any([pretrained_model_string == s for s in MODEL_NAMES]):
                pretrained_model_string = "denoise_cyto3"
            pretrained_model = model_path(pretrained_model_string)
            if (pretrained_model and not os.path.exists(pretrained_model)):
                denoise_logger.warning("pretrained model has incorrect path")
            denoise_logger.info(f">> {pretrained_model_string} << model set to be used")
            # nuclei models were trained at a smaller mean diameter
            self.diam_mean = 17. if "nuclei" in pretrained_model_string else 30.
        else:
            if pretrained_model:
                builtin = False
                pretrained_model_string = pretrained_model
                denoise_logger.info(f">>>> loading model {pretrained_model_string}")

        # assign network device
        if device is None:
            sdevice, gpu = assign_device(use_torch=True, gpu=gpu)
        self.device = device if device is not None else sdevice
        if device is not None:
            device_gpu = self.device.type == "cuda"
        self.gpu = gpu if device is None else device_gpu

        # create network
        self.nclasses = 1
        nbase = [32, 64, 128, 256]
        self.nbase = [nchan, *nbase]

        self.net = CPnet(self.nbase, self.nclasses, sz=3,
                         max_pool=True, diam_mean=diam_mean).to(self.device)

        self.pretrained_model = pretrained_model
        self.net_chan2 = None
        if self.pretrained_model:
            self.net.load_model(self.pretrained_model, device=self.device)
            denoise_logger.info(
                f">>>> model diam_mean = {self.diam_mean: .3f} (ROIs rescaled to this size during training)"
            )
            # optional second network for the nuclear channel (built-in models only)
            if chan2 and builtin:
                chan2_path = model_path(
                    os.path.split(self.pretrained_model)[-1].split("_")[0] + "_nuclei")
                print(f"loading model for chan2: {os.path.split(str(chan2_path))[-1]}")
                self.net_chan2 = CPnet(self.nbase, self.nclasses, sz=3,
                                       max_pool=True,
                                       diam_mean=17.).to(self.device)
                self.net_chan2.load_model(chan2_path, device=self.device)
        self.net_type = "cellpose_denoise"

    def eval(self, x, batch_size=8, channels=None, channel_axis=None, z_axis=None,
             normalize=True, rescale=None, diameter=None, tile=True, do_3D=False,
             tile_overlap=0.1, bsize=224):
        """
        Restore array or list of images using the image restoration model.

        Args:
            x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
            batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
                (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
            channels (list, optional): list of channels, either of length 2 or of length number of images by 2.
                First element of list is the channel to segment (0=grayscale, 1=red, 2=green, 3=blue).
                Second element of list is the optional nuclear channel (0=none, 1=red, 2=green, 3=blue).
                For instance, to segment grayscale images, input [0,0]. To segment images with cells
                in green and nuclei in blue, input [2,3]. To segment one grayscale image and one
                image with cells in green and nuclei in blue, input [[0,0], [2,3]].
                Defaults to None.
            channel_axis (int, optional): channel axis in element of list x, or of np.ndarray x.
                if None, channels dimension is attempted to be automatically determined. Defaults to None.
            z_axis (int, optional): z axis in element of list x, or of np.ndarray x.
                if None, z dimension is attempted to be automatically determined. Defaults to None.
            normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
                can also pass dictionary of parameters (all keys are optional, default values shown):
                    - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
                    - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
                    - "normalize"=True ; run normalization (if False, all following parameters ignored)
                    - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
                    - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
                    - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
                Defaults to True.
            rescale (float, optional): resize factor for each image, if None, set to 1.0;
                (only used if diameter is None). Defaults to None.
            diameter (float, optional): diameter for each image,
                if diameter is None, set to diam_mean or diam_train if available. Defaults to None.
            tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.

        Returns:
            list: A list of 2D/3D arrays of restored images

        """
        # list (or 5D stack) input: recurse per image, forwarding per-image
        # channels / rescale / diameter when given as sequences.
        if isinstance(x, list) or x.squeeze().ndim == 5:
            tqdm_out = utils.TqdmToLogger(denoise_logger, level=logging.INFO)
            nimg = len(x)
            iterator = trange(nimg, file=tqdm_out,
                              mininterval=30) if nimg > 1 else range(nimg)
            imgs = []
            for i in iterator:
                imgi = self.eval(
                    x[i], batch_size=batch_size,
                    channels=channels[i] if channels is not None and
                    ((len(channels) == len(x) and
                      (isinstance(channels[i], list) or
                       isinstance(channels[i], np.ndarray)) and len(channels[i]) == 2))
                    else channels, channel_axis=channel_axis, z_axis=z_axis,
                    normalize=normalize,
                    do_3D=do_3D,
                    rescale=rescale[i] if isinstance(rescale, list) or
                    isinstance(rescale, np.ndarray) else rescale,
                    diameter=diameter[i] if isinstance(diameter, list) or
                    isinstance(diameter, np.ndarray) else diameter,
                    tile_overlap=tile_overlap, bsize=bsize)
                imgs.append(imgi)
            if isinstance(x, np.ndarray):
                imgs = np.array(imgs)
            return imgs

        else:
            # reshape image
            x = transforms.convert_image(x, channels, channel_axis=channel_axis,
                                         z_axis=z_axis, do_3D=do_3D, nchan=None)
            if x.ndim < 4:
                squeeze = True
                x = x[np.newaxis, ...]
            else:
                squeeze = False

            # may need to interpolate image before running upsampling
            self.ratio = 1.
            if "upsample" in self.pretrained_model:
                Ly, Lx = x.shape[-3:-1]
                if diameter is not None and 3 <= diameter < self.diam_mean:
                    self.ratio = self.diam_mean / diameter
                    denoise_logger.info(
                        f"upsampling image to {self.diam_mean} pixel diameter ({self.ratio:0.2f} times)"
                    )
                    Lyr, Lxr = int(Ly * self.ratio), int(Lx * self.ratio)
                    x = transforms.resize_image(x, Ly=Lyr, Lx=Lxr)
                else:
                    denoise_logger.warning(
                        f"not interpolating image before upsampling because diameter is set >= {self.diam_mean}"
                    )

            self.batch_size = batch_size

            if diameter is not None and diameter > 0:
                rescale = self.diam_mean / diameter
            elif rescale is None:
                rescale = 1.0

            # drop a flat (or explicitly disabled) second channel
            if np.ptp(x[..., -1]) < 1e-3 or (channels is not None and channels[-1] == 0):
                x = x[..., :1]

            # run each channel through its network; channel 1 (nuclei) is
            # rescaled 30/17 and uses net_chan2 when available.
            for c in range(x.shape[-1]):
                rescale0 = rescale * 30. / 17. if c == 1 else rescale
                if c == 0 or self.net_chan2 is None:
                    x[..., c] = self._eval(self.net, x[..., c:c + 1],
                                           batch_size=batch_size,
                                           normalize=normalize, rescale=rescale0,
                                           tile_overlap=tile_overlap, bsize=bsize)[..., 0]
                else:
                    x[..., c] = self._eval(self.net_chan2, x[..., c:c + 1],
                                           batch_size=batch_size,
                                           normalize=normalize, rescale=rescale0,
                                           tile_overlap=tile_overlap, bsize=bsize)[..., 0]
            x = x[0] if squeeze else x
            return x

    def _eval(self, net, x, batch_size=8, normalize=True, rescale=None,
              tile_overlap=0.1, bsize=224):
        """
        Run image restoration model on a single channel.

        Args:
            net (CPnet): network to run (self.net or self.net_chan2).
            x (list, np.ndarry): can be list of 2D/3D/4D images, or array of 2D/3D/4D images
            batch_size (int, optional): number of 224x224 patches to run simultaneously on the GPU
                (can make smaller or bigger depending on GPU memory usage). Defaults to 8.
            normalize (bool, optional): if True, normalize data so 0.0=1st percentile and 1.0=99th percentile of image intensities in each channel;
                can also pass dictionary of parameters (all keys are optional, default values shown):
                    - "lowhigh"=None : pass in normalization values for 0.0 and 1.0 as list [low, high] (if not None, all following parameters ignored)
                    - "sharpen"=0 ; sharpen image with high pass filter, recommended to be 1/4-1/8 diameter of cells in pixels
                    - "normalize"=True ; run normalization (if False, all following parameters ignored)
                    - "percentile"=None : pass in percentiles to use as list [perc_low, perc_high]
                    - "tile_norm"=0 ; compute normalization in tiles across image to brighten dark areas, to turn on set to window size in pixels (e.g. 100)
                    - "norm3D"=False ; compute normalization across entire z-stack rather than plane-by-plane in stitching mode.
                Defaults to True.
            rescale (float, optional): resize factor for each image, if None, set to 1.0;
                (only used if diameter is None). Defaults to None.
            tile_overlap (float, optional): fraction of overlap of tiles when computing flows. Defaults to 0.1.

        Returns:
            list: A list of 2D/3D arrays of restored images

        """
        if isinstance(normalize, dict):
            normalize_params = {**normalize_default, **normalize}
        elif not isinstance(normalize, bool):
            raise ValueError("normalize parameter must be a bool or a dict")
        else:
            normalize_params = normalize_default
            normalize_params["normalize"] = normalize

        tic = time.time()
        shape = x.shape
        nimg = shape[0]

        do_normalization = True if normalize_params["normalize"] else False

        img = np.asarray(x)
        if do_normalization:
            img = transforms.normalize_img(img, **normalize_params)
        if rescale != 1.0:
            img = transforms.resize_image(img, rsz=rescale)
        # BUGFIX: run the network passed in (`net`) rather than always
        # self.net — previously the chan2 (nuclei) network handed over by
        # eval() was silently ignored.
        yf, style = run_net(net, img, bsize=bsize,
                            tile_overlap=tile_overlap)
        # resize back to the original (pre-rescale) image size
        yf = transforms.resize_image(yf, shape[1], shape[2])
        imgs = yf
        del yf, style

        net_time = time.time() - tic
        if nimg > 1:
            denoise_logger.info("imgs denoised in %2.2fs" % (net_time))

        return imgs
|
| 884 |
+
|
| 885 |
+
|
| 886 |
+
def train(net, train_data=None, train_labels=None, train_files=None, test_data=None,
          test_labels=None, test_files=None, train_probs=None, test_probs=None,
          lam=[1., 1.5, 0.], scale_range=0.5, seg_model_type="cyto2", save_path=None,
          save_every=100, save_each=False, poisson=0.7, beta=0.7, blur=0.7, gblur=1.0,
          iso=True, uniform_blur=False, downsample=0., ds_max=7,
          learning_rate=0.005, n_epochs=500,
          weight_decay=0.00001, batch_size=8, nimg_per_epoch=None,
          nimg_test_per_epoch=None, model_name=None):
    """Train the denoising network `net` with synthetic noise augmentation.

    Data comes either from in-memory arrays (train_data/train_labels) or from
    files on disk (train_files, with "*_flows.tif" label files alongside).
    `lam` weights the (perceptual, segmentation, reconstruction) loss terms
    and is treated as read-only here. Returns (filename, train_losses,
    test_losses) where filename is the saved model path (or None if
    save_path is None).
    """
    # net properties
    device = net.device
    nchan = net.nchan
    diam_mean = net.diam_mean.item()

    # Noise parameters may be scalars or per-noise-type sequences; normalize
    # to a 2D array so each row indexes one noise configuration.
    args = np.array([poisson, beta, blur, gblur, downsample])
    if args.ndim == 1:
        args = args[:, np.newaxis]
    poisson, beta, blur, gblur, downsample = args
    nnoise = len(poisson)

    d = datetime.datetime.now()
    if save_path is not None:
        if model_name is None:
            # auto-generate a descriptive filename from losses/noise settings
            filename = ""
            lstrs = ["per", "seg", "rec"]
            for k, (l, s) in enumerate(zip(lam, lstrs)):
                filename += f"{s}_{l:.2f}_"
            if not iso:
                filename += "aniso_"
            if poisson.sum() > 0:
                filename += "poisson_"
            if blur.sum() > 0:
                filename += "blur_"
            if downsample.sum() > 0:
                filename += "downsample_"
            filename += d.strftime("%Y_%m_%d_%H_%M_%S.%f")
            filename = os.path.join(save_path, filename)
        else:
            filename = os.path.join(save_path, model_name)
        print(filename)
    for i in range(len(poisson)):
        denoise_logger.info(
            f"poisson: {poisson[i]: 0.2f}, beta: {beta[i]: 0.2f}, blur: {blur[i]: 0.2f}, gblur: {gblur[i]: 0.2f}, downsample: {downsample[i]: 0.2f}"
        )
    # frozen single-channel segmentation net used for the seg loss term
    net1 = one_chan_cellpose(device=device, pretrained_model=seg_model_type)

    # LR schedule: linear warmup (10 steps), constant, then 10 halvings.
    learning_rate_const = learning_rate
    LR = np.linspace(0, learning_rate_const, 10)
    LR = np.append(LR, learning_rate_const * np.ones(n_epochs - 100))
    for i in range(10):
        LR = np.append(LR, LR[-1] / 2 * np.ones(10))
    learning_rate = LR

    batch_size = 8
    optimizer = torch.optim.AdamW(net.parameters(), lr=learning_rate[0],
                                  weight_decay=weight_decay)
    # compute per-image diameters (clamped to >= 5 px) for noise scaling
    if train_data is not None:
        nimg = len(train_data)
        diam_train = np.array(
            [utils.diameters(train_labels[k])[0] for k in trange(len(train_labels))])
        diam_train[diam_train < 5] = 5.
        if test_data is not None:
            diam_test = np.array(
                [utils.diameters(test_labels[k])[0] for k in trange(len(test_labels))])
            diam_test[diam_test < 5] = 5.
            nimg_test = len(test_data)
    else:
        nimg = len(train_files)
        denoise_logger.info(">>> using files instead of loading dataset")
        train_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in train_files]
        denoise_logger.info(">>> computing diameters")
        diam_train = np.array([
            utils.diameters(io.imread(train_labels_files[k])[0])[0]
            for k in trange(len(train_labels_files))
        ])
        diam_train[diam_train < 5] = 5.
        if test_files is not None:
            nimg_test = len(test_files)
            test_labels_files = [str(tf)[:-4] + f"_flows.tif" for tf in test_files]
            diam_test = np.array([
                utils.diameters(io.imread(test_labels_files[k])[0])[0]
                for k in trange(len(test_labels_files))
            ])
            diam_test[diam_test < 5] = 5.
    # default to uniform sampling probabilities
    train_probs = 1. / nimg * np.ones(nimg,
                                      "float64") if train_probs is None else train_probs
    if test_files is not None or test_data is not None:
        test_probs = 1. / nimg_test * np.ones(
            nimg_test, "float64") if test_probs is None else test_probs

    tic = time.time()

    nimg_per_epoch = nimg if nimg_per_epoch is None else nimg_per_epoch
    if test_files is not None or test_data is not None:
        nimg_test_per_epoch = nimg_test if nimg_test_per_epoch is None else nimg_test_per_epoch

    nbatch = 0
    train_losses, test_losses = [], []
    for iepoch in range(n_epochs):
        # seed per-epoch for reproducible sampling
        np.random.seed(iepoch)
        rperm = np.random.choice(np.arange(0, nimg), size=(nimg_per_epoch,),
                                 p=train_probs)
        torch.manual_seed(iepoch)
        np.random.seed(iepoch)
        for param_group in optimizer.param_groups:
            param_group["lr"] = learning_rate[iepoch]
        lavg, lavg_per, nsum = 0, 0, 0
        for ibatch in range(0, nimg_per_epoch, batch_size * nnoise):
            inds = rperm[ibatch : ibatch + batch_size * nnoise]
            if train_data is None:
                imgs = [np.maximum(0, io.imread(train_files[i])[:nchan]) for i in inds]
                lbls = [io.imread(train_labels_files[i])[1:] for i in inds]
            else:
                imgs = [train_data[i][:nchan] for i in inds]
                lbls = [train_labels[i][1:] for i in inds]
            # each sub-batch gets a randomly-ordered noise configuration
            rnoise = np.random.permutation(nnoise)
            for i, inoise in enumerate(rnoise):
                if i * batch_size < len(imgs):
                    imgi, lbli, scale = random_rotate_and_resize_noise(
                        imgs[i * batch_size : (i + 1) * batch_size],
                        lbls[i * batch_size : (i + 1) * batch_size],
                        diam_train[inds][i * batch_size : (i + 1) * batch_size].copy(),
                        poisson=poisson[inoise],
                        beta=beta[inoise], gblur=gblur[inoise], blur=blur[inoise], iso=iso,
                        downsample=downsample[inoise], uniform_blur=uniform_blur,
                        diam_mean=diam_mean, ds_max=ds_max,
                        device=device)
                    if i == 0:
                        img = imgi
                        lbl = lbli
                    else:
                        img = torch.cat((img, imgi), axis=0)
                        lbl = torch.cat((lbl, lbli), axis=0)

            # shuffle so each optimizer step mixes noise types
            if nnoise > 0:
                iperm = np.random.permutation(img.shape[0])
                img, lbl = img[iperm], lbl[iperm]

            for i in range(nnoise):
                optimizer.zero_grad()
                imgi = img[i * batch_size: (i + 1) * batch_size]
                lbli = lbl[i * batch_size: (i + 1) * batch_size]
                if imgi.shape[0] > 0:
                    loss, loss_per = train_loss(net, imgi[:, :nchan], net1=net1,
                                                img=imgi[:, nchan:], lbl=lbli, lam=lam)
                    loss.backward()
                    optimizer.step()
                    lavg += loss.item() * imgi.shape[0]
                    lavg_per += loss_per.item() * imgi.shape[0]

            nsum += len(img)
            nbatch += 1

        # periodic logging (and test-set evaluation when available)
        if iepoch % 5 == 0 or iepoch < 10:
            lavg = lavg / nsum
            lavg_per = lavg_per / nsum
            if test_data is not None or test_files is not None:
                lavgt, nsum = 0., 0
                np.random.seed(42)
                rperm = np.random.choice(np.arange(0, nimg_test),
                                         size=(nimg_test_per_epoch,), p=test_probs)
                inoise = iepoch % nnoise
                torch.manual_seed(inoise)
                for ibatch in range(0, nimg_test_per_epoch, batch_size):
                    inds = rperm[ibatch:ibatch + batch_size]
                    if test_data is None:
                        imgs = [
                            np.maximum(0,
                                       io.imread(test_files[i])[:nchan]) for i in inds
                        ]
                        lbls = [io.imread(test_labels_files[i])[1:] for i in inds]
                    else:
                        imgs = [test_data[i][:nchan] for i in inds]
                        lbls = [test_labels[i][1:] for i in inds]
                    img, lbl, scale = random_rotate_and_resize_noise(
                        imgs, lbls, diam_test[inds].copy(), poisson=poisson[inoise],
                        beta=beta[inoise], blur=blur[inoise], gblur=gblur[inoise],
                        iso=iso, downsample=downsample[inoise], uniform_blur=uniform_blur,
                        diam_mean=diam_mean, ds_max=ds_max, device=device)
                    loss, loss_per = test_loss(net, img[:, :nchan], net1=net1,
                                               img=img[:, nchan:], lbl=lbl, lam=lam)

                    lavgt += loss.item() * img.shape[0]
                    nsum += len(img)
                lavgt = lavgt / nsum
                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, Loss Test %0.3f, LR %2.4f"
                    % (iepoch, time.time() - tic, lavg, lavg_per, lavgt,
                       learning_rate[iepoch]))
                test_losses.append(lavgt)
            else:
                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %0.3f, loss_per %0.3f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, lavg_per, learning_rate[iepoch]))
            train_losses.append(lavg)

        if save_path is not None:
            if iepoch == n_epochs - 1 or (iepoch % save_every == 0 and iepoch != 0):
                if save_each:  # separate files as model progresses
                    # BUGFIX: format spec was "{iepoch:%04d}" — '%' is not a
                    # valid presentation type for ints and raised ValueError
                    # whenever save_each=True; "04d" zero-pads as intended.
                    filename0 = str(filename) + f"_epoch_{iepoch:04d}"
                else:
                    filename0 = filename
                denoise_logger.info(f"saving network parameters to {filename0}")
                net.save_model(filename0)
        else:
            filename = save_path

    return filename, train_losses, test_losses
|
| 1095 |
+
|
| 1096 |
+
|
| 1097 |
+
if __name__ == "__main__":
    # Command-line entry point for training the denoising network.
    import argparse
    parser = argparse.ArgumentParser(description="cellpose parameters")

    input_img_args = parser.add_argument_group("input image arguments")
    input_img_args.add_argument("--dir", default=[], type=str,
                                help="folder containing data to run or train on.")
    input_img_args.add_argument("--img_filter", default=[], type=str,
                                help="end string for images to run on")

    model_args = parser.add_argument_group("model arguments")
    model_args.add_argument("--pretrained_model", default=[], type=str,
                            help="pretrained denoising model")

    training_args = parser.add_argument_group("training arguments")
    training_args.add_argument("--test_dir", default=[], type=str,
                               help="folder containing test data (optional)")
    training_args.add_argument("--file_list", default=[], type=str,
                               help="npy file containing list of train and test files")
    training_args.add_argument("--seg_model_type", default="cyto2", type=str,
                               help="model to use for seg training loss")
    training_args.add_argument(
        "--noise_type", default=[], type=str,
        help="noise type to use (if input, then other noise params are ignored)")
    training_args.add_argument("--poisson", default=0.8, type=float,
                               help="fraction of images to add poisson noise to")
    training_args.add_argument("--beta", default=0.7, type=float,
                               help="scale of poisson noise")
    training_args.add_argument("--blur", default=0., type=float,
                               help="fraction of images to blur")
    training_args.add_argument("--gblur", default=1.0, type=float,
                               help="scale of gaussian blurring stddev")
    training_args.add_argument("--downsample", default=0., type=float,
                               help="fraction of images to downsample")
    training_args.add_argument("--ds_max", default=7, type=int,
                               help="max downsampling factor")
    training_args.add_argument("--lam_per", default=1.0, type=float,
                               help="weighting of perceptual loss")
    training_args.add_argument("--lam_seg", default=1.5, type=float,
                               help="weighting of segmentation loss")
    training_args.add_argument("--lam_rec", default=0., type=float,
                               help="weighting of reconstruction loss")
    training_args.add_argument(
        "--diam_mean", default=30., type=float, help=
        "mean diameter to resize cells to during training -- if starting from pretrained models it cannot be changed from 30.0"
    )
    training_args.add_argument("--learning_rate", default=0.001, type=float,
                               help="learning rate. Default: %(default)s")
    training_args.add_argument("--n_epochs", default=2000, type=int,
                               help="number of epochs. Default: %(default)s")
    training_args.add_argument(
        "--save_each", default=False, action="store_true",
        help="save each epoch as separate model")
    training_args.add_argument(
        "--nimg_per_epoch", default=0, type=int,
        help="number of images per epoch. Default is length of training images")
    training_args.add_argument(
        "--nimg_test_per_epoch", default=0, type=int,
        help="number of test images per epoch. Default is length of testing images")

    io.logger_setup()

    args = parser.parse_args()
    lams = [args.lam_per, args.lam_seg, args.lam_rec]
    print("lam", lams)

    # BUGFIX: initialize iso / uniform_blur BEFORE branching on noise_type.
    # Previously both were only assigned inside the `if len(args.noise_type) > 0`
    # branch, so running with the default (no --noise_type) raised NameError at
    # the train(...) call below, which passes iso=iso, uniform_blur=uniform_blur.
    # The defaults mirror the initial values used inside the noise_type branch.
    uniform_blur = False
    iso = True
    if len(args.noise_type) > 0:
        # A named noise preset overrides the individual corruption parameters.
        noise_type = args.noise_type
        if noise_type == "poisson":
            poisson = 0.8
            blur = 0.
            downsample = 0.
            beta = 0.7
            gblur = 1.0
        elif noise_type == "blur_expr":
            poisson = 0.8
            blur = 0.8
            downsample = 0.
            beta = 0.1
            gblur = 0.5
        elif noise_type == "blur":
            poisson = 0.8
            blur = 0.8
            downsample = 0.
            beta = 0.1
            gblur = 10.0
            uniform_blur = True
        elif noise_type == "downsample_expr":
            poisson = 0.8
            blur = 0.8
            downsample = 0.8
            beta = 0.03
            gblur = 1.0
        elif noise_type == "downsample":
            poisson = 0.8
            blur = 0.8
            downsample = 0.8
            beta = 0.03
            gblur = 5.0
            uniform_blur = True
        elif noise_type == "all":
            # per-corruption-type lists: [poisson-only, blur, downsample]
            poisson = [0.8, 0.8, 0.8]
            blur = [0., 0.8, 0.8]
            downsample = [0., 0., 0.8]
            beta = [0.7, 0.1, 0.03]
            gblur = [0., 10.0, 5.0]
            uniform_blur = True
        elif noise_type == "aniso":
            poisson = 0.8
            blur = 0.8
            downsample = 0.8
            beta = 0.1
            gblur = args.ds_max * 1.5
            iso = False
        else:
            raise ValueError(f"{noise_type} noise_type is not supported")
    else:
        # no preset: use the individually supplied corruption parameters
        poisson, beta = args.poisson, args.beta
        blur, gblur = args.blur, args.gblur
        downsample = args.downsample

    pretrained_model = None if len(
        args.pretrained_model) == 0 else args.pretrained_model
    model = DenoiseModel(gpu=True, nchan=1, diam_mean=args.diam_mean,
                         pretrained_model=pretrained_model)

    train_data, labels, train_files, train_probs = None, None, None, None
    test_data, test_labels, test_files, test_probs = None, None, None, None
    if len(args.file_list) == 0:
        # load images from folders and normalize them in memory
        output = io.load_train_test_data(args.dir, args.test_dir, "_img", "_masks", 0)
        images, labels, image_names, test_images, test_labels, image_names_test = output
        train_data = []
        for i in range(len(images)):
            img = images[i].astype("float32")
            if img.ndim > 2:
                img = img[0]  # keep first channel only
            train_data.append(
                np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
        if len(args.test_dir) > 0:
            test_data = []
            for i in range(len(test_images)):
                img = test_images[i].astype("float32")
                if img.ndim > 2:
                    img = img[0]
                test_data.append(
                    np.maximum(transforms.normalize99(img), 0)[np.newaxis, :, :])
        save_path = os.path.join(args.dir, "../models/")
    else:
        # file-list mode: images are loaded lazily by the train loop
        root = args.dir
        denoise_logger.info(
            ">>> using file_list (assumes images are normalized and have flows!)")
        dat = np.load(args.file_list, allow_pickle=True).item()
        train_files = dat["train_files"]
        test_files = dat["test_files"]
        train_probs = dat["train_probs"] if "train_probs" in dat else None
        test_probs = dat["test_probs"] if "test_probs" in dat else None
        # re-root stored paths if the list was created on another machine
        if str(train_files[0])[:len(str(root))] != str(root):
            for i in range(len(train_files)):
                new_path = root / Path(*train_files[i].parts[-3:])
                if i == 0:
                    print(f"changing path from {train_files[i]} to {new_path}")
                train_files[i] = new_path

            for i in range(len(test_files)):
                new_path = root / Path(*test_files[i].parts[-3:])
                test_files[i] = new_path
        save_path = os.path.join(args.dir, "models/")

    os.makedirs(save_path, exist_ok=True)

    # 0 means "use full dataset length" (train() interprets None that way)
    nimg_per_epoch = None if args.nimg_per_epoch == 0 else args.nimg_per_epoch
    nimg_test_per_epoch = None if args.nimg_test_per_epoch == 0 else args.nimg_test_per_epoch

    model_path = train(
        model.net, train_data=train_data, train_labels=labels, train_files=train_files,
        test_data=test_data, test_labels=test_labels, test_files=test_files,
        train_probs=train_probs, test_probs=test_probs, poisson=poisson, beta=beta,
        blur=blur, gblur=gblur, downsample=downsample, ds_max=args.ds_max,
        iso=iso, uniform_blur=uniform_blur, n_epochs=args.n_epochs,
        learning_rate=args.learning_rate,
        lam=lams,
        seg_model_type=args.seg_model_type, nimg_per_epoch=nimg_per_epoch,
        nimg_test_per_epoch=nimg_test_per_epoch, save_path=save_path)
|
| 1282 |
+
|
| 1283 |
+
|
| 1284 |
+
def seg_train_noisy(model, train_data, train_labels, test_data=None, test_labels=None,
                    poisson=0.8, blur=0.0, downsample=0.0, save_path=None,
                    save_every=100, save_each=False, learning_rate=0.2, n_epochs=500,
                    momentum=0.9, weight_decay=0.00001, SGD=True, batch_size=8,
                    nimg_per_epoch=None, diameter=None, rescale=True, z_masking=False,
                    model_name=None):
    """Train a segmentation model directly on noise-corrupted images.

    Uses the loss function model.loss_fn in models.py. Each batch is corrupted
    on the fly with random_rotate_and_resize_noise (poisson noise / blur /
    downsampling) and only the noisy channel is fed to the network, so the
    model learns to segment degraded data.

    (data should already be normalized)

    Args:
        model: segmentation model wrapper; must provide net, device, diam_mean,
            _set_optimizer, _set_criterion, _set_learning_rate, _train_step,
            _test_eval (project type -- not constructible here).
        train_data (list): per-image arrays, channels first.
        train_labels (list): per-image label arrays; [0] is the mask image,
            [1:] are the flow targets passed to the loss.
        test_data / test_labels (list, optional): held-out evaluation set.
        poisson, blur, downsample (float): corruption fractions/strengths.
        save_path (str, optional): root folder; models go in save_path/models/.
        save_every (int): epoch interval for checkpointing.
        save_each (bool): write a separate file per checkpoint epoch.
        learning_rate (float or array): constant LR, or per-epoch schedule of
            length n_epochs.
        n_epochs, momentum, weight_decay, SGD, batch_size: optimizer settings.
        nimg_per_epoch (int, optional): images sampled per epoch (capped at
            the dataset size).
        diameter (float, optional): fixed object diameter; if None it is
            estimated from train_labels.
        rescale (bool): rescale images by diameter during augmentation.
        z_masking (bool): randomly zero leading/trailing channels per batch.
        model_name (str, optional): base name for saved model files.

    Returns:
        str: path of the last saved model file (or save_path when not saving).
    """

    d = datetime.datetime.now()  # timestamp used in default model filenames

    model.n_epochs = n_epochs
    if isinstance(learning_rate, (list, np.ndarray)):
        # caller supplied a full per-epoch schedule; validate its shape
        if isinstance(learning_rate, np.ndarray) and learning_rate.ndim > 1:
            raise ValueError("learning_rate.ndim must equal 1")
        elif len(learning_rate) != n_epochs:
            raise ValueError(
                "if learning_rate given as list or np.ndarray it must have length n_epochs"
            )
        model.learning_rate = learning_rate
        # representative constant LR for logging: the modal value of the schedule
        model.learning_rate_const = mode(learning_rate)[0][0]
    else:
        model.learning_rate_const = learning_rate
        # set learning rate schedule
        if SGD:
            # 10-epoch linear warmup from 0 to the constant LR
            LR = np.linspace(0, model.learning_rate_const, 10)
            if model.n_epochs > 250:
                # hold constant, then halve every 10 epochs for the last 100
                LR = np.append(
                    LR, model.learning_rate_const * np.ones(model.n_epochs - 100))
                for i in range(10):
                    LR = np.append(LR, LR[-1] / 2 * np.ones(10))
            else:
                LR = np.append(
                    LR,
                    model.learning_rate_const * np.ones(max(0, model.n_epochs - 10)))
        else:
            # adaptive optimizer: flat schedule
            LR = model.learning_rate_const * np.ones(model.n_epochs)
        model.learning_rate = LR

    model.batch_size = batch_size
    model._set_optimizer(model.learning_rate[0], momentum, weight_decay, SGD)
    model._set_criterion()

    nimg = len(train_data)

    # compute average cell diameter
    if diameter is None:
        # estimate per-image diameters from the mask channel of the labels
        diam_train = np.array(
            [utils.diameters(train_labels[k][0])[0] for k in range(len(train_labels))])
        diam_train_mean = diam_train[diam_train > 0].mean()
        model.diam_labels = diam_train_mean
        if rescale:
            diam_train[diam_train < 5] = 5.  # floor to avoid extreme upscaling
            if test_data is not None:
                diam_test = np.array([
                    utils.diameters(test_labels[k][0])[0]
                    for k in range(len(test_labels))
                ])
                diam_test[diam_test < 5] = 5.
            denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
    elif rescale:
        # fixed diameter supplied by caller; broadcast to all images
        diam_train_mean = diameter
        model.diam_labels = diameter
        denoise_logger.info(">>>> median diameter set to = %d" % model.diam_mean)
        diam_train = diameter * np.ones(len(train_labels), "float32")
        if test_data is not None:
            diam_test = diameter * np.ones(len(test_labels), "float32")
    # NOTE(review): if diameter is given but rescale is False, diam_train /
    # diam_train_mean are never assigned and the lines below raise NameError --
    # confirm callers always pass rescale=True with an explicit diameter.

    denoise_logger.info(
        f">>>> mean of training label mask diameters (saved to model) {diam_train_mean:.3f}"
    )
    model.net.diam_labels.data = torch.ones(1, device=model.device) * diam_train_mean

    nchan = train_data[0].shape[0]
    denoise_logger.info(">>>> training network with %d channel input <<<<" % nchan)
    denoise_logger.info(">>>> LR: %0.5f, batch_size: %d, weight_decay: %0.5f" %
                        (model.learning_rate_const, model.batch_size, weight_decay))

    if test_data is not None:
        denoise_logger.info(f">>>> ntrain = {nimg}, ntest = {len(test_data)}")
    else:
        denoise_logger.info(f">>>> ntrain = {nimg}")

    tic = time.time()

    lavg, nsum = 0, 0  # running loss sum and sample count since last log

    if save_path is not None:
        _, file_label = os.path.split(save_path)
        file_path = os.path.join(save_path, "models/")

        if not os.path.exists(file_path):
            os.makedirs(file_path)
    else:
        denoise_logger.warning("WARNING: no save_path given, model not saving")

    ksave = 0  # number of checkpoints written so far

    # get indices for each epoch for training
    np.random.seed(0)
    inds_all = np.zeros((0,), "int32")
    if nimg_per_epoch is None or nimg > nimg_per_epoch:
        nimg_per_epoch = nimg
    denoise_logger.info(f">>>> nimg_per_epoch = {nimg_per_epoch}")
    # pre-draw enough shuffled indices to cover every epoch deterministically
    while len(inds_all) < n_epochs * nimg_per_epoch:
        rperm = np.random.permutation(nimg)
        inds_all = np.hstack((inds_all, rperm))

    for iepoch in range(model.n_epochs):
        if SGD:
            model._set_learning_rate(model.learning_rate[iepoch])
        np.random.seed(iepoch)  # make augmentation reproducible per epoch
        rperm = inds_all[iepoch * nimg_per_epoch:(iepoch + 1) * nimg_per_epoch]
        for ibatch in range(0, nimg_per_epoch, batch_size):
            inds = rperm[ibatch:ibatch + batch_size]
            # augment + corrupt the batch; lbl carries the flow targets
            imgi, lbl, scale = random_rotate_and_resize_noise(
                [train_data[i] for i in inds], [train_labels[i][1:] for i in inds],
                poisson=poisson, blur=blur, downsample=downsample,
                diams=diam_train[inds], diam_mean=model.diam_mean)
            imgi = imgi[:, :1]  # keep noisy only
            if z_masking:
                # randomly zero a prefix/suffix of channels for ~75% of images
                nc = imgi.shape[1]
                nb = imgi.shape[0]
                ncmin = (np.random.rand(nb) > 0.25) * (np.random.randint(
                    nc // 2 - 1, size=nb))
                ncmax = nc - (np.random.rand(nb) > 0.25) * (np.random.randint(
                    nc // 2 - 1, size=nb))
                for b in range(nb):
                    imgi[b, :ncmin[b]] = 0
                    imgi[b, ncmax[b]:] = 0

            train_loss = model._train_step(imgi, lbl)
            lavg += train_loss
            nsum += len(imgi)

        # periodic logging (and test evaluation) every 10 epochs, plus epoch 5
        if iepoch % 10 == 0 or iepoch == 5:
            lavg = lavg / nsum
            if test_data is not None:
                lavgt, nsum = 0., 0
                np.random.seed(42)  # fixed seed -> identical test corruption each eval
                rperm = np.arange(0, len(test_data), 1, int)
                for ibatch in range(0, len(test_data), batch_size):
                    inds = rperm[ibatch:ibatch + batch_size]
                    imgi, lbl, scale = random_rotate_and_resize_noise(
                        [test_data[i] for i in inds],
                        [test_labels[i][1:] for i in inds], poisson=poisson, blur=blur,
                        downsample=downsample, diams=diam_test[inds],
                        diam_mean=model.diam_mean)
                    imgi = imgi[:, :1]  # keep noisy only
                    test_loss = model._test_eval(imgi, lbl)
                    lavgt += test_loss
                    nsum += len(imgi)

                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %2.4f, Loss Test %2.4f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, lavgt / nsum,
                     model.learning_rate[iepoch]))
            else:
                denoise_logger.info(
                    "Epoch %d, Time %4.1fs, Loss %2.4f, LR %2.4f" %
                    (iepoch, time.time() - tic, lavg, model.learning_rate[iepoch]))

            lavg, nsum = 0, 0  # reset running averages for next window

        if save_path is not None:
            if iepoch == model.n_epochs - 1 or iepoch % save_every == 1:
                # save model at the end
                if save_each:  #separate files as model progresses
                    if model_name is None:
                        filename = "{}_{}_{}_{}".format(
                            model.net_type, file_label,
                            d.strftime("%Y_%m_%d_%H_%M_%S.%f"), "epoch_" + str(iepoch))
                    else:
                        filename = "{}_{}".format(model_name, "epoch_" + str(iepoch))
                else:
                    if model_name is None:
                        filename = "{}_{}_{}".format(model.net_type, file_label,
                                                     d.strftime("%Y_%m_%d_%H_%M_%S.%f"))
                    else:
                        filename = model_name
                filename = os.path.join(file_path, filename)
                ksave += 1
                # NOTE(review): message literal looks garbled -- probably meant
                # f"saving network parameters to {filename}"; confirm intent
                # before changing the logged string.
                denoise_logger.info(f"saving network parameters to (unknown)")
                model.net.save_model(filename)
        else:
            filename = save_path

    return filename
|
models/seg_post_model/cellpose/dynamics.py
ADDED
|
@@ -0,0 +1,691 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer , Michael Rariden and Marius Pachitariu.
|
| 3 |
+
"""
|
| 4 |
+
import os
|
| 5 |
+
from scipy.ndimage import find_objects, center_of_mass, mean
|
| 6 |
+
import torch
|
| 7 |
+
import numpy as np
|
| 8 |
+
import tifffile
|
| 9 |
+
from tqdm import trange
|
| 10 |
+
import fastremap
|
| 11 |
+
|
| 12 |
+
import logging
|
| 13 |
+
|
| 14 |
+
dynamics_logger = logging.getLogger(__name__)
|
| 15 |
+
|
| 16 |
+
from . import utils
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import torch.nn.functional as F
|
| 20 |
+
|
| 21 |
+
def _extend_centers_gpu(neighbors, meds, isneighbor, shape, n_iter=200,
                        device=torch.device("cpu")):
    """Runs diffusion on GPU to generate flows for training images or quality control.

    Heat is repeatedly injected at each mask center and averaged over the
    pixel neighborhoods (restricted to the same mask); spatial gradients of
    the resulting heat field are the flow vectors.

    Args:
        neighbors (torch.Tensor): coords of the 9 (2D) or 7 (3D) neighbors for
            each mask pixel; first index is the pixel itself.
        meds (torch.Tensor): Mask centers.
        isneighbor (torch.Tensor): Valid neighbor boolean (neighbor belongs to
            the same mask), 9 x pixels.
        shape (tuple): Shape of the padded mask tensor.
        n_iter (int, optional): Number of diffusion iterations. Defaults to 200.
        device (torch.device, optional): Device to run the computation on.
            Defaults to torch.device("cpu").

    Returns:
        np.ndarray: gradients of the heat field, stacked on axis -2 as
        (dy, dx) in 2D or (dz, dy, dx) in 3D.
    """
    # float32 for very large fields (memory) and for MPS (no float64 support);
    # float64 otherwise for numerical accuracy of the diffusion
    if torch.prod(torch.tensor(shape)) > 4e7 or device.type == "mps":
        T = torch.zeros(shape, dtype=torch.float, device=device)
    else:
        T = torch.zeros(shape, dtype=torch.double, device=device)

    for i in range(n_iter):
        T[tuple(meds.T)] += 1  # inject heat at each mask center
        Tneigh = T[tuple(neighbors)]
        Tneigh *= isneighbor  # zero contributions from outside the mask
        # replace each pixel with the mean over its (masked) neighborhood
        T[tuple(neighbors[:, 0])] = Tneigh.mean(axis=0)
    del meds, isneighbor, Tneigh  # free GPU memory before gradient step

    if T.ndim == 2:
        # central differences: neighbor order is [y+1, y-1, x+1, x-1]
        grads = T[neighbors[0, [2, 1, 4, 3]], neighbors[1, [2, 1, 4, 3]]]
        del neighbors
        dy = grads[0] - grads[1]
        dx = grads[2] - grads[3]
        del grads
        mu_torch = np.stack((dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
    else:
        # 3D: six axial neighbors give dz, dy, dx central differences
        grads = T[tuple(neighbors[:, 1:])]
        del neighbors
        dz = grads[0] - grads[1]
        dy = grads[2] - grads[3]
        dx = grads[4] - grads[5]
        del grads
        mu_torch = np.stack(
            (dz.cpu().squeeze(0), dy.cpu().squeeze(0), dx.cpu().squeeze(0)), axis=-2)
    return mu_torch
|
| 66 |
+
|
| 67 |
+
def center_of_mass(mask):
    """Return the (y, x) pixel of ``mask`` closest to its centroid.

    The centroid of all nonzero pixels is rounded to integer coordinates;
    if that rounded point does not itself lie inside the mask, the nearest
    in-mask pixel (squared Euclidean distance) is returned instead.

    Args:
        mask (2D array): boolean/integer mask of a single object.

    Returns:
        tuple: (y, x) integer coordinates of the chosen center pixel.
    """
    rows, cols = np.nonzero(mask)
    cy = int(np.round(rows.sum() / len(rows)))
    cx = int(np.round(cols.sum() / len(cols)))
    # does the rounded centroid coincide with any mask pixel?
    hits = ((rows == cy) * (cols == cx)).sum()
    if not hits:
        # fall back to the in-mask pixel nearest to (cy, cx)
        nearest = ((cols - cx) ** 2 + (rows - cy) ** 2).argmin()
        cy = rows[nearest]
        cx = cols[nearest]

    return cy, cx
|
| 78 |
+
|
| 79 |
+
def get_centers(masks, slices):
    """Compute per-mask center coordinates and bounding-box extents.

    Args:
        masks (2D array): labelled masks (0 = background, 1..N = labels).
        slices (list): N (slice_y, slice_x) bounding boxes, one per label,
            as produced by scipy.ndimage.find_objects.

    Returns:
        tuple: (centers, exts) -- centers is an (N, 2) int array of (y, x)
        center pixels in full-image coordinates; exts is an (N,) array of
        bounding-box height + width + 2 for each mask.
    """
    center_list = []
    ext_list = []
    for idx in range(len(slices)):
        slc = slices[idx]
        # center within the cropped bounding box, then shift to image coords
        local = center_of_mass(masks[slc] == (idx + 1))
        center_list.append(
            np.array([local[0] + slc[0].start, local[1] + slc[1].start]))
        ext_list.append(
            (slc[0].stop - slc[0].start) + (slc[1].stop - slc[1].start) + 2)
    return np.array(center_list), np.array(ext_list)
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def masks_to_flows_gpu(masks, device=torch.device("cpu"), niter=None):
    """Convert masks to flows using diffusion from center pixel.

    Center of masks where diffusion starts is defined by pixel closest to median within the mask.

    Args:
        masks (int, 2D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
        device (torch.device, optional): The device to run the computation on.
            Defaults to torch.device("cpu"); if None is passed explicitly,
            CUDA/MPS is auto-selected when available.
        niter (int, optional): Number of iterations for the diffusion process.
            Defaults to None (2x the largest mask extent).

    Returns:
        np.ndarray: float [2 x Ly x Lx] array of (dy, dx) flows; zeros where
        there is no mask (or everywhere, if there are no masks at all).
    """
    if device is None:
        # auto-select an accelerator; falls back to None (torch treats
        # .to(None) as a no-op, i.e. CPU tensors stay on CPU)
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

    if masks.max() > 0:
        Ly0, Lx0 = masks.shape
        Ly, Lx = Ly0 + 2, Lx0 + 2  # padded size (Ly/Lx themselves unused below)

        # pad by 1 pixel so every mask pixel has all 9 neighbors in-bounds
        masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
        masks_padded = F.pad(masks_padded, (1, 1, 1, 1))
        shape = masks_padded.shape

        ### get mask pixel neighbors
        y, x = torch.nonzero(masks_padded, as_tuple=True)
        y = y.int()
        x = x.int()
        # neighbors[:, 0] is the pixel itself; 1..8 are the 8-connected offsets
        neighbors = torch.zeros((2, 9, y.shape[0]), dtype=torch.int, device=device)
        yxi = [[0, -1, 1, 0, 0, -1, -1, 1, 1], [0, 0, 0, -1, 1, -1, 1, -1, 1]]
        for i in range(9):
            neighbors[0, i] = y + yxi[0][i]
            neighbors[1, i] = x + yxi[1][i]
        # a neighbor is valid only if it carries the same mask label
        isneighbor = torch.ones((9, y.shape[0]), dtype=torch.bool, device=device)
        m0 = masks_padded[neighbors[0, 0], neighbors[1, 0]]
        for i in range(1, 9):
            isneighbor[i] = masks_padded[neighbors[0, i], neighbors[1, i]] == m0
        del m0, masks_padded  # free GPU memory before diffusion

        ### get center-of-mass within cell
        slices = find_objects(masks)
        centers, ext = get_centers(masks, slices)
        meds_p = torch.from_numpy(centers).to(device).long()
        meds_p += 1  # for padding

        ### run diffusion
        n_iter = 2 * ext.max() if niter is None else niter
        mu = _extend_centers_gpu(neighbors, meds_p, isneighbor, shape, n_iter=n_iter,
                                 device=device)
        mu = mu.astype("float64")

        # new normalization
        mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)

        # put into original image (-1 undoes the padding offset)
        mu0 = np.zeros((2, Ly0, Lx0))
        mu0[:, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
    else:
        # no masks, return empty flows
        mu0 = np.zeros((2, masks.shape[0], masks.shape[1]))
    return mu0
|
| 153 |
+
|
| 154 |
+
def masks_to_flows_gpu_3d(masks, device=None, niter=None):
    """Convert masks to flows using diffusion from center pixel.

    Args:
        masks (int, 3D array): Labelled masks. 0=NO masks; 1,2,...=mask labels.
        device (torch.device, optional): The device to run the computation on.
            Defaults to None (auto-select CUDA/MPS when available).
        niter (int, optional): Number of iterations for the diffusion process.
            Defaults to None (6x the largest summed mask extent).

    Returns:
        np.ndarray: A 4D array [3 x Lz x Ly x Lx] of (dz, dy, dx) flows for
        each pixel.
    """
    if device is None:
        device = torch.device('cuda') if torch.cuda.is_available() else torch.device('mps') if torch.backends.mps.is_available() else None

    Lz0, Ly0, Lx0 = masks.shape
    Lz, Ly, Lx = Lz0 + 2, Ly0 + 2, Lx0 + 2  # padded size (unused below)

    # pad by 1 voxel so all 6 axial neighbors are in-bounds
    masks_padded = torch.from_numpy(masks.astype("int64")).to(device)
    masks_padded = F.pad(masks_padded, (1, 1, 1, 1, 1, 1))

    # get mask pixel neighbors: index 0 is the voxel itself, 1..6 the
    # +/-z, +/-y, +/-x axial neighbors
    z, y, x = torch.nonzero(masks_padded).T
    neighborsZ = torch.stack((z, z + 1, z - 1, z, z, z, z))
    neighborsY = torch.stack((y, y, y, y + 1, y - 1, y, y), axis=0)
    neighborsX = torch.stack((x, x, x, x, x, x + 1, x - 1), axis=0)

    neighbors = torch.stack((neighborsZ, neighborsY, neighborsX), axis=0)

    # get mask centers
    slices = find_objects(masks)

    centers = np.zeros((masks.max(), 3), "int")
    for i, si in enumerate(slices):
        if si is not None:
            sz, sy, sx = si
            zi, yi, xi = np.nonzero(masks[sz, sy, sx] == (i + 1))
            zi = zi.astype(np.int32) + 1  # add padding
            yi = yi.astype(np.int32) + 1  # add padding
            xi = xi.astype(np.int32) + 1  # add padding
            # center = in-mask voxel closest to the mask centroid
            zmed = np.mean(zi)
            ymed = np.mean(yi)
            xmed = np.mean(xi)
            imin = np.argmin((zi - zmed)**2 + (xi - xmed)**2 + (yi - ymed)**2)
            zmed = zi[imin]
            ymed = yi[imin]
            xmed = xi[imin]
            # NOTE: the +1 padding offset is already folded into zi/yi/xi
            # above, so these coords index the PADDED volume
            centers[i, 0] = zmed + sz.start
            centers[i, 1] = ymed + sy.start
            centers[i, 2] = xmed + sx.start

    # get neighbor validator (not all neighbors are in same mask)
    neighbor_masks = masks_padded[tuple(neighbors)]
    isneighbor = neighbor_masks == neighbor_masks[0]
    ext = np.array(
        [[sz.stop - sz.start + 1, sy.stop - sy.start + 1, sx.stop - sx.start + 1]
         for sz, sy, sx in slices])
    n_iter = 6 * (ext.sum(axis=1)).max() if niter is None else niter

    # run diffusion
    shape = masks_padded.shape
    mu = _extend_centers_gpu(neighbors, centers, isneighbor, shape, n_iter=n_iter,
                             device=device)
    # normalize
    mu /= (1e-60 + (mu**2).sum(axis=0)**0.5)

    # put into original image (-1 undoes the padding offset)
    mu0 = np.zeros((3, Lz0, Ly0, Lx0))
    mu0[:, z.cpu().numpy() - 1, y.cpu().numpy() - 1, x.cpu().numpy() - 1] = mu
    return mu0
|
| 224 |
+
|
| 225 |
+
def labels_to_flows(labels, files=None, device=None, redo_flows=False, niter=None,
                    return_flows=True):
    """Converts labels (list of masks or flows) to flows for training model.

    Args:
        labels (list of ND-arrays): The labels to convert. labels[k] can be 2D or 3D. If [3 x Ly x Lx],
            it is assumed that flows were precomputed. Otherwise, labels[k][0] or labels[k] (if 2D)
            is used to create flows and cell probabilities.
        files (list of str, optional): The files to save the flows to. If provided, flows are saved to
            files to be reused. Defaults to None.
        device (str, optional): The device to use for computation. Defaults to None.
        redo_flows (bool, optional): Whether to recompute the flows. Defaults to False.
        niter (int, optional): The number of iterations for computing flows. Defaults to None.
        return_flows (bool, optional): Whether to collect and return the computed flows.
            If False, flows are only written to disk (when ``files`` is given) and an
            empty list is returned. Defaults to True.

    Returns:
        list of [4 x Ly x Lx] arrays: The flows for training the model. flows[k][0] is labels[k],
            flows[k][1] is cell distance transform, flows[k][2] is Y flow, flows[k][3] is X flow,
            and flows[k][4] is heat distribution.
    """
    nimg = len(labels)
    # Promote 2D labels to [1 x Ly x Lx] so the precomputed-flow check below works uniformly.
    if labels[0].ndim < 3:
        labels = [labels[n][np.newaxis, :, :] for n in range(nimg)]

    flows = []
    # flows need to be recomputed
    if labels[0].shape[0] == 1 or labels[0].ndim < 3 or redo_flows:
        dynamics_logger.info("computing flows for labels")

        # compute flows; labels are fixed here to be unique, so they need to be passed back
        # make sure labels are unique!
        labels = [fastremap.renumber(label, in_place=True)[0] for label in labels]
        # Only show a progress bar when there is more than one image.
        iterator = trange if nimg > 1 else range
        for n in iterator(nimg):
            labels[n][0] = fastremap.renumber(labels[n][0], in_place=True)[0]
            vecn = masks_to_flows_gpu(labels[n][0].astype(int), device=device, niter=niter)

            # concatenate labels, distance transform, vector flows, heat (boundary and mask are computed in augmentations)
            flow = np.concatenate((labels[n], labels[n] > 0.5, vecn),
                                  axis=0).astype(np.float32)
            if files is not None:
                file_name = os.path.splitext(files[n])[0]
                tifffile.imwrite(file_name + "_flows.tif", flow)
            if return_flows:
                flows.append(flow)
    else:
        # Flows were precomputed (labels[k] already has the flow channels); just cast.
        dynamics_logger.info("flows precomputed")
        if return_flows:
            flows = [labels[n].astype(np.float32) for n in range(nimg)]
    return flows
|
| 274 |
+
|
| 275 |
+
|
| 276 |
+
def flow_error(maski, dP_net, device=None):
    """Error in flows from predicted masks vs flows predicted by network run on image.

    This function serves to benchmark the quality of masks. It works as follows:
    1. The predicted masks are used to create a flow diagram.
    2. The mask-flows are compared to the flows that the network predicted.

    If there is a discrepancy between the flows, it suggests that the mask is incorrect.
    Masks with flow_errors greater than 0.4 are discarded by default. This setting can be
    changed in Cellpose.eval or CellposeModel.eval.

    Args:
        maski (np.ndarray, int): Masks produced from running dynamics on dP_net, where 0=NO masks; 1,2... are mask labels.
        dP_net (np.ndarray, float): ND flows where dP_net.shape[1:] = maski.shape.
        device (torch.device, optional): Device used to recompute flows from the masks.

    Returns:
        A tuple containing (flow_errors, dP_masks): flow_errors (np.ndarray, float): Mean squared error between predicted flows and flows from masks;
        dP_masks (np.ndarray, float): ND flows produced from the predicted masks.
        Returns None (implicitly) if the shapes of dP_net and maski do not match.
    """
    if dP_net.shape[1:] != maski.shape:
        print("ERROR: net flow is not same size as predicted masks")
        return

    # flows predicted from estimated masks
    dP_masks = masks_to_flows_gpu(maski, device=device)
    # difference between predicted flows vs mask flows
    # dP_net is divided by 5 to undo the 5x scaling applied to network flows
    # elsewhere in the pipeline (see compute_masks, which uses dP/5).
    # `mean` here is presumably scipy.ndimage.mean (per-label mean) — its
    # (input, labels, index) call signature matches; import is outside this view.
    flow_errors = np.zeros(maski.max())
    for i in range(dP_masks.shape[0]):
        flow_errors += mean((dP_masks[i] - dP_net[i] / 5.)**2, maski,
                            index=np.arange(1,
                                            maski.max() + 1))

    return flow_errors, dP_masks
|
| 309 |
+
|
| 310 |
+
|
| 311 |
+
def steps_interp(dP, inds, niter, device=torch.device("cpu")):
    """ Run dynamics of pixels to recover masks in 2D/3D, with interpolation between pixel values.

    Euler integration of dynamics dP for niter steps. The flow field is sampled
    with bilinear/trilinear interpolation (torch grid_sample) at the current,
    possibly fractional, pixel positions.

    Args:
        dP (numpy.ndarray): Array of shape (2, Ly, Lx) or (3, Lz, Ly, Lx) representing the flow field.
        inds (tuple of numpy.ndarray): Per-axis integer indices of the starting pixels,
            as returned by np.nonzero (length 2 in 2D, 3 in 3D).
        niter (int): Number of iterations to perform.
        device (torch.device, optional): Device to use for computation. Defaults to CPU.

    Returns:
        torch.Tensor: Array of shape (2, n_points) or (3, n_points) representing the
        final pixel locations (axis order matches dP's spatial axes).
    """

    shape = dP.shape[1:]
    ndim = len(shape)

    # pt holds sampling coordinates for grid_sample: shape (1, [1,] 1, n_points, ndim).
    pt = torch.zeros((*[1]*ndim, len(inds[0]), ndim), dtype=torch.float32, device=device)
    im = torch.zeros((1, ndim, *shape), dtype=torch.float32, device=device)
    # Y and X dimensions, flipped X-1, Y-1
    # pt is [1 1 1 3 n_points]
    # grid_sample expects coordinates in (x, y[, z]) order, i.e. reversed
    # relative to array indexing — hence the ndim - n - 1 flip below.
    for n in range(ndim):
        if ndim==3:
            pt[0, 0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
        else:
            pt[0, 0, :, ndim - n - 1] = torch.from_numpy(inds[n]).to(device, dtype=torch.float32)
        im[0, ndim - n - 1] = torch.from_numpy(dP[n]).to(device, dtype=torch.float32)
    # From here on, `shape` is the reversed (x, y[, z]) extents minus one,
    # used to map pixel coordinates into grid_sample's normalized space.
    shape = np.array(shape)[::-1].astype("float") - 1

    # normalize pt between 0 and 1, normalize the flow
    for k in range(ndim):
        im[:, k] *= 2. / shape[k]
        pt[..., k] /= shape[k]

    # normalize to between -1 and 1
    pt *= 2
    pt -= 1

    # dynamics
    for t in range(niter):
        dPt = torch.nn.functional.grid_sample(im, pt, align_corners=False)
        for k in range(ndim): #clamp the final pixel locations
            pt[..., k] = torch.clamp(pt[..., k] + dPt[:, k], -1., 1.)

    #undo the normalization from before, reverse order of operations
    pt += 1
    pt *= 0.5
    for k in range(ndim):
        pt[..., k] *= shape[k]

    # Flip coordinates back from (x, y[, z]) to array order (z, y, x) / (y, x)
    # and return as (ndim, n_points); the unsqueeze guards the single-point case.
    if ndim==3:
        pt = pt[..., [2, 1, 0]].squeeze()
        pt = pt.unsqueeze(0) if pt.ndim==1 else pt
        return pt.T
    else:
        pt = pt[..., [1, 0]].squeeze()
        pt = pt.unsqueeze(0) if pt.ndim==1 else pt
        return pt.T
|
| 374 |
+
|
| 375 |
+
def follow_flows(dP, inds, niter=200, device=torch.device("cpu")):
    """ Run dynamics to recover masks in 2D or 3D.

    Only pixels with non-zero cell probability (as selected by ``inds``) are
    advected through the flow field; this is a thin wrapper around
    ``steps_interp``.

    Args:
        dP (np.ndarray): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
        inds (tuple of np.ndarray): Per-axis indices of the pixels to move,
            as returned by np.nonzero.
        niter (int, optional): Number of iterations of dynamics to run. Default is 200.
        device (torch.device, optional): Device to use for computation. Default is CPU.

    Returns:
        torch.Tensor: Final locations of each seeded pixel, shape (ndim, n_points).
    """
    # Previous version computed unused `shape`/`ndim` locals here; removed.
    return steps_interp(dP, inds, niter, device=device)
|
| 398 |
+
|
| 399 |
+
|
| 400 |
+
def remove_bad_flow_masks(masks, flows, threshold=0.4, device=torch.device("cpu")):
    """Remove masks which have inconsistent flows.

    Uses metrics.flow_error to compute flows from predicted masks
    and compare flows to predicted flows from the network. Discards
    masks with flow errors greater than the threshold.

    Args:
        masks (int, 2D or 3D array): Labelled masks, 0=NO masks; 1,2,...=mask labels,
            size [Ly x Lx] or [Lz x Ly x Lx]. Modified in place.
        flows (float, 3D or 4D array): Flows [axis x Ly x Lx] or [axis x Lz x Ly x Lx].
        threshold (float, optional): Masks with flow error greater than threshold are discarded.
            Default is 0.4.
        device (torch.device, optional): Preferred device; may be downgraded to CPU
            for very large images (see below). Default is CPU.

    Returns:
        masks (int, 2D or 3D array): Masks with inconsistent flow masks removed,
            0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx].
    """
    device0 = device
    # For very large images on CUDA, check free GPU memory first and fall back
    # to CPU if the flow recomputation would not fit.
    if masks.size > 10000 * 10000 and (device is not None and device.type == "cuda"):

        major_version, minor_version = torch.__version__.split(".")[:2]
        torch.cuda.empty_cache()
        if major_version == "1" and int(minor_version) < 10:
            # for PyTorch version lower than 1.10
            # (torch.cuda.mem_get_info was added in 1.10, so estimate free
            # memory from device properties and the current allocation)
            def mem_info():
                total_mem = torch.cuda.get_device_properties(device0.index).total_memory
                used_mem = torch.cuda.memory_allocated(device0.index)
                free_mem = total_mem - used_mem
                return total_mem, free_mem
        else:
            # for PyTorch version 1.10 and above
            def mem_info():
                free_mem, total_mem = torch.cuda.mem_get_info(device0.index)
                return total_mem, free_mem
        total_mem, free_mem = mem_info()
        # 32 bytes per pixel is the rough working-set estimate used here.
        if masks.size * 32 > free_mem:
            dynamics_logger.warning(
                "WARNING: image is very large, not using gpu to compute flows from masks for QC step flow_threshold"
            )
            dynamics_logger.info("turn off QC step with flow_threshold=0 if too slow")
            device0 = torch.device("cpu")

    merrors, _ = flow_error(masks, flows, device0)
    # flow_error returns one error per label starting at label 1, hence the +1.
    badi = 1 + (merrors > threshold).nonzero()[0]
    masks[np.isin(masks, badi)] = 0
    return masks
|
| 447 |
+
|
| 448 |
+
|
| 449 |
+
def max_pool1d(h, kernel_size=5, axis=1, out=None):
    """ memory efficient max_pool thanks to Mark Kittisopikul

    for stride=1, padding=kernel_size//2, requires odd kernel_size >= 3

    Pools along a single axis (1, 2, or 3) by sliding shifted views of the
    input over an accumulator and taking elementwise maxima, so no windowed
    buffer is ever materialized.
    """
    # Accumulator starts as a copy of the input (padding positions then keep
    # the input value, matching stride-1 / pad-k//2 pooling).
    if out is None:
        out = h.clone()
    else:
        out.copy_(h)

    length = h.shape[axis]
    half = kernel_size // 2
    for shift in range(-half, half + 1):
        dst_slice = slice(max(-shift, 0), min(length - shift, length))
        src_slice = slice(max(shift, 0), min(length + shift, length))
        if axis == 1:
            dst = out[:, dst_slice]
            src = h[:, src_slice]
        elif axis == 2:
            dst = out[:, :, dst_slice]
            src = h[:, :, src_slice]
        elif axis == 3:
            dst = out[:, :, :, dst_slice]
            src = h[:, :, :, src_slice]
        # In-place elementwise max into the shifted accumulator view.
        torch.maximum(dst, src, out=dst)
    return out
|
| 474 |
+
|
| 475 |
+
def max_pool_nd(h, kernel_size=5):
    """ memory efficient max_pool in 2d or 3d

    Applies max_pool1d successively along each trailing spatial axis of h
    (axis 0 is treated as a batch/channel axis).
    """
    spatial_dims = h.ndim - 1
    pooled_a = max_pool1d(h, kernel_size=kernel_size, axis=1)
    pooled_b = max_pool1d(pooled_a, kernel_size=kernel_size, axis=2)
    if spatial_dims == 2:
        del pooled_a
        return pooled_b
    # 3D: pool the last axis as well, reusing the first buffer to limit peak memory.
    pooled_a = max_pool1d(pooled_b, kernel_size=kernel_size, axis=3, out=pooled_a)
    del pooled_b
    return pooled_a
|
| 487 |
+
|
| 488 |
+
def get_masks_torch(pt, inds, shape0, rpad=20, max_size_fraction=0.4):
    """Create masks using pixel convergence after running dynamics.

    Makes a histogram of final pixel locations p, initializes masks
    at peaks of histogram and extends the masks from the peaks so that
    they include all pixels with more than 2 final pixels p. Discards
    masks with flow errors greater than the threshold.

    Parameters:
        pt (torch.Tensor, int): Final locations of the seeded pixels after dynamics,
            shape (ndim, n_points). Modified in place (shifted by the padding).
        inds (tuple of np.ndarray): Original per-axis indices of the seeded pixels
            (as returned by np.nonzero), used to scatter labels back into the image.
        shape0 (tuple): Spatial shape of the output mask image.
        rpad (int, optional): Histogram edge padding. NOTE(review): this parameter
            is currently ignored — it is overwritten with 20 below.
        max_size_fraction (float, optional): Masks larger than max_size_fraction of
            total image size are removed. Default is 0.4.

    Returns:
        M0 (int, 2D or 3D array): Masks with inconsistent flow masks removed,
            0=NO masks; 1,2,...=mask labels, size [Ly x Lx] or [Lz x Ly x Lx].
    """

    ndim = len(shape0)
    device = pt.device

    # Shift all coordinates into the padded histogram and clamp to its bounds.
    rpad = 20
    pt += rpad
    pt = torch.clamp(pt, min=0)
    for i in range(len(pt)):
        pt[i] = torch.clamp(pt[i], max=shape0[i]+rpad-1)

    # # add extra padding to make divisible by 5
    # shape = tuple((np.ceil((shape0 + 2*rpad)/5) * 5).astype(int))
    shape = tuple(np.array(shape0) + 2*rpad)

    # sparse coo torch
    # Build the histogram of final pixel locations: duplicate coordinates are
    # summed when the sparse tensor is densified.
    coo = torch.sparse_coo_tensor(pt, torch.ones(pt.shape[1], device=pt.device, dtype=torch.int),
                                  shape)
    h1 = coo.to_dense()
    del coo

    # Seeds are local maxima (within a 5-wide window) that collect > 10 pixels.
    hmax1 = max_pool_nd(h1.unsqueeze(0), kernel_size=5)
    hmax1 = hmax1.squeeze()
    seeds1 = torch.nonzero((h1 - hmax1 > -1e-6) * (h1 > 10))
    del hmax1
    if len(seeds1) == 0:
        dynamics_logger.warning("no seeds found in get_masks_torch - no masks found.")
        return np.zeros(shape0, dtype="uint16")

    # Process seeds in order of increasing pixel count.
    npts = h1[tuple(seeds1.T)]
    isort1 = npts.argsort()
    seeds1 = seeds1[isort1]

    # Cut an 11-wide neighborhood of the histogram around each seed.
    n_seeds = len(seeds1)
    h_slc = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
    for k in range(n_seeds):
        slc = tuple([slice(seeds1[k][j]-5, seeds1[k][j]+6) for j in range(ndim)])
        h_slc[k] = h1[slc]
    del h1
    # Start each mask as the single center pixel of its neighborhood.
    seed_masks = torch.zeros((n_seeds, *[11]*ndim), device=seeds1.device)
    if ndim==2:
        seed_masks[:,5,5] = 1
    else:
        seed_masks[:,5,5,5] = 1

    # Grow each mask by dilation, keeping only cells with > 2 converged pixels.
    for iter in range(5):
        # extend
        seed_masks = max_pool_nd(seed_masks, kernel_size=3)
        seed_masks *= h_slc > 2
    del h_slc
    # Convert the local neighborhoods back to padded-image coordinates.
    seeds_new = [tuple((torch.nonzero(seed_masks[k]) + seeds1[k] - 5).T)
                 for k in range(n_seeds)]
    del seed_masks

    # Paint each seed region with its label in the padded volume.
    dtype = torch.int32 if n_seeds < 2**16 else torch.int64
    M1 = torch.zeros(shape, dtype=dtype, device=device)
    for k in range(n_seeds):
        M1[seeds_new[k]] = 1 + k

    # Look up, for every seeded pixel, the label at its converged location.
    M1 = M1[tuple(pt)]
    M1 = M1.cpu().numpy()

    # Scatter labels back to the original (unpadded) image positions.
    dtype = "uint16" if n_seeds < 2**16 else "uint32"
    M0 = np.zeros(shape0, dtype=dtype)
    M0[inds] = M1

    # remove big masks
    uniq, counts = fastremap.unique(M0, return_counts=True)
    big = np.prod(shape0) * max_size_fraction
    bigc = uniq[counts > big]
    if len(bigc) > 0 and (len(bigc) > 1 or bigc[0] != 0):
        M0 = fastremap.mask(M0, bigc)
    fastremap.renumber(M0, in_place=True) #convenient to guarantee non-skipped labels
    M0 = M0.reshape(tuple(shape0))

    #print(f"mem used: {torch.cuda.memory_allocated()/1e9:.3f} gb, max mem used: {torch.cuda.max_memory_allocated()/1e9:.3f} gb")
    return M0
|
| 585 |
+
|
| 586 |
+
|
| 587 |
+
def resize_and_compute_masks(dP, cellprob, niter=200, cellprob_threshold=0.0,
                             flow_threshold=0.4, do_3D=False, min_size=15,
                             max_size_fraction=0.4, resize=None, device=torch.device("cpu")):
    """Compute masks using dynamics from dP and cellprob, then clean them up.

    Thin wrapper around ``compute_masks`` that additionally fills holes and
    removes masks smaller than ``min_size``. The ``resize`` argument is kept
    for backward compatibility only and now just triggers a warning.

    Args:
        dP (numpy.ndarray): The dynamics flow field array.
        cellprob (numpy.ndarray): The cell probability array.
        niter (int, optional): The number of iterations for mask computation. Defaults to 200.
        cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0.
        flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4.
        do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
        min_size (int, optional): The minimum size of the masks. Defaults to 15.
        max_size_fraction (float, optional): Masks larger than max_size_fraction of
            total image size are removed. Default is 0.4.
        resize (tuple, optional): Deprecated; ignored except for a warning. Defaults to None.
        device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu").

    Returns:
        numpy.ndarray: The computed masks (0 = background, 1,2,... = labels).
    """
    mask = compute_masks(dP, cellprob, niter=niter,
                         cellprob_threshold=cellprob_threshold,
                         flow_threshold=flow_threshold, do_3D=do_3D,
                         max_size_fraction=max_size_fraction,
                         device=device)

    if resize is not None:
        # fixed typo in the warning message ("depricated" -> "deprecated")
        dynamics_logger.warning("Resizing is deprecated in v4.0.1+")

    mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)

    return mask
|
| 622 |
+
|
| 623 |
+
|
| 624 |
+
def compute_masks(dP, cellprob, p=None, niter=200, cellprob_threshold=0.0,
                  flow_threshold=0.4, do_3D=False, min_size=-1,
                  max_size_fraction=0.4, device=torch.device("cpu")):
    """Compute masks using dynamics from dP and cellprob.

    Args:
        dP (numpy.ndarray): The dynamics flow field array.
        cellprob (numpy.ndarray): The cell probability array.
        p (numpy.ndarray, optional): Unused; kept for backward compatibility. Defaults to None.
        niter (int, optional): The number of iterations for mask computation. Defaults to 200.
        cellprob_threshold (float, optional): The threshold for cell probability. Defaults to 0.0.
        flow_threshold (float, optional): The threshold for quality control metrics. Defaults to 0.4.
        do_3D (bool, optional): Whether to perform mask computation in 3D. Defaults to False.
        min_size (int, optional): The minimum size of the masks; values <= 0 disable
            small-mask removal. Defaults to -1 (disabled).
        max_size_fraction (float, optional): Masks larger than max_size_fraction of
            total image size are removed. Default is 0.4.
        device (torch.device, optional): The device to use for computation. Defaults to torch.device("cpu").

    Returns:
        numpy.ndarray: The computed masks (0 = background, 1,2,... = labels),
        uint16 when possible, uint32 when there are more than 65535 masks.
    """

    if (cellprob > cellprob_threshold).sum(): #mask at this point is a cell cluster binary map, not labels
        inds = np.nonzero(cellprob > cellprob_threshold)
        # NOTE(review): this branch is unreachable — the outer condition already
        # guarantees at least one pixel above threshold. Kept for safety.
        if len(inds[0]) == 0:
            dynamics_logger.info("No cell pixels found.")
            shape = cellprob.shape
            mask = np.zeros(shape, "uint16")
            return mask

        # Advect only above-threshold pixels; flows are scaled down by 5
        # (network flows are trained at 5x magnitude).
        p_final = follow_flows(dP * (cellprob > cellprob_threshold) / 5.,
                               inds=inds, niter=niter,
                               device=device)
        if not torch.is_tensor(p_final):
            p_final = torch.from_numpy(p_final).to(device, dtype=torch.int)
        else:
            p_final = p_final.int()
        # calculate masks
        # (sparse tensor ops used in get_masks_torch are run on CPU for MPS)
        if device.type == "mps":
            p_final = p_final.to(torch.device("cpu"))
        mask = get_masks_torch(p_final, inds, dP.shape[1:],
                               max_size_fraction=max_size_fraction)
        del p_final
        # flow thresholding factored out of get_masks
        if not do_3D:
            if mask.max() > 0 and flow_threshold is not None and flow_threshold > 0:
                # make sure labels are unique at output of get_masks
                mask = remove_bad_flow_masks(mask, dP, threshold=flow_threshold,
                                             device=device)

            if mask.max() < 2**16 and mask.dtype != "uint16":
                mask = mask.astype("uint16")

    else: # nothing to compute, just make it compatible
        dynamics_logger.info("No cell pixels found.")
        shape = cellprob.shape
        mask = np.zeros(cellprob.shape, "uint16")
        return mask

    if min_size > 0:
        mask = utils.fill_holes_and_remove_small_masks(mask, min_size=min_size)

    if mask.dtype == np.uint32:
        dynamics_logger.warning(
            "more than 65535 masks in image, masks returned as np.uint32")

    return mask
|
models/seg_post_model/cellpose/export.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""Auxiliary module for bioimageio format export
|
| 2 |
+
|
| 3 |
+
Example usage:
|
| 4 |
+
|
| 5 |
+
```bash
|
| 6 |
+
#!/bin/bash
|
| 7 |
+
|
| 8 |
+
# Define default paths and parameters
|
| 9 |
+
DEFAULT_CHANNELS="1 0"
|
| 10 |
+
DEFAULT_PATH_PRETRAINED_MODEL="/home/qinyu/models/cp/cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995"
|
| 11 |
+
DEFAULT_PATH_README="/home/qinyu/models/cp/README.md"
|
| 12 |
+
DEFAULT_LIST_PATH_COVER_IMAGES="/home/qinyu/images/cp/cellpose_raw_and_segmentation.jpg /home/qinyu/images/cp/cellpose_raw_and_probability.jpg /home/qinyu/images/cp/cellpose_raw.jpg"
|
| 13 |
+
DEFAULT_MODEL_ID="philosophical-panda"
|
| 14 |
+
DEFAULT_MODEL_ICON="🐼"
|
| 15 |
+
DEFAULT_MODEL_VERSION="0.1.0"
|
| 16 |
+
DEFAULT_MODEL_NAME="My Cool Cellpose"
|
| 17 |
+
DEFAULT_MODEL_DOCUMENTATION="A cool Cellpose model trained for my cool dataset."
|
| 18 |
+
DEFAULT_MODEL_AUTHORS='[{"name": "Qin Yu", "affiliation": "EMBL", "github_user": "qin-yu", "orcid": "0000-0002-4652-0795"}]'
|
| 19 |
+
DEFAULT_MODEL_CITE='[{"text": "For more details of the model itself, see the manuscript", "doi": "10.1242/dev.202800", "url": null}]'
|
| 20 |
+
DEFAULT_MODEL_TAGS="cellpose 3d 2d"
|
| 21 |
+
DEFAULT_MODEL_LICENSE="MIT"
|
| 22 |
+
DEFAULT_MODEL_REPO="https://github.com/kreshuklab/go-nuclear"
|
| 23 |
+
|
| 24 |
+
# Run the Python script with default parameters
|
| 25 |
+
python export.py \
|
| 26 |
+
--channels $DEFAULT_CHANNELS \
|
| 27 |
+
--path_pretrained_model "$DEFAULT_PATH_PRETRAINED_MODEL" \
|
| 28 |
+
--path_readme "$DEFAULT_PATH_README" \
|
| 29 |
+
--list_path_cover_images $DEFAULT_LIST_PATH_COVER_IMAGES \
|
| 30 |
+
--model_version "$DEFAULT_MODEL_VERSION" \
|
| 31 |
+
--model_name "$DEFAULT_MODEL_NAME" \
|
| 32 |
+
--model_documentation "$DEFAULT_MODEL_DOCUMENTATION" \
|
| 33 |
+
--model_authors "$DEFAULT_MODEL_AUTHORS" \
|
| 34 |
+
--model_cite "$DEFAULT_MODEL_CITE" \
|
| 35 |
+
--model_tags $DEFAULT_MODEL_TAGS \
|
| 36 |
+
--model_license "$DEFAULT_MODEL_LICENSE" \
|
| 37 |
+
--model_repo "$DEFAULT_MODEL_REPO"
|
| 38 |
+
```
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
import os
|
| 42 |
+
import sys
|
| 43 |
+
import json
|
| 44 |
+
import argparse
|
| 45 |
+
from pathlib import Path
|
| 46 |
+
from urllib.parse import urlparse
|
| 47 |
+
|
| 48 |
+
import torch
|
| 49 |
+
import numpy as np
|
| 50 |
+
|
| 51 |
+
from cellpose.io import imread
|
| 52 |
+
from cellpose.utils import download_url_to_file
|
| 53 |
+
from cellpose.transforms import pad_image_ND, normalize_img, convert_image
|
| 54 |
+
from cellpose.vit_sam import CPnetBioImageIO
|
| 55 |
+
|
| 56 |
+
from bioimageio.spec.model.v0_5 import (
|
| 57 |
+
ArchitectureFromFileDescr,
|
| 58 |
+
Author,
|
| 59 |
+
AxisId,
|
| 60 |
+
ChannelAxis,
|
| 61 |
+
CiteEntry,
|
| 62 |
+
Doi,
|
| 63 |
+
FileDescr,
|
| 64 |
+
Identifier,
|
| 65 |
+
InputTensorDescr,
|
| 66 |
+
IntervalOrRatioDataDescr,
|
| 67 |
+
LicenseId,
|
| 68 |
+
ModelDescr,
|
| 69 |
+
ModelId,
|
| 70 |
+
OrcidId,
|
| 71 |
+
OutputTensorDescr,
|
| 72 |
+
ParameterizedSize,
|
| 73 |
+
PytorchStateDictWeightsDescr,
|
| 74 |
+
SizeReference,
|
| 75 |
+
SpaceInputAxis,
|
| 76 |
+
SpaceOutputAxis,
|
| 77 |
+
TensorId,
|
| 78 |
+
TorchscriptWeightsDescr,
|
| 79 |
+
Version,
|
| 80 |
+
WeightsDescr,
|
| 81 |
+
)
|
| 82 |
+
# Define ARBITRARY_SIZE if it is not available in the module
|
| 83 |
+
try:
|
| 84 |
+
from bioimageio.spec.model.v0_5 import ARBITRARY_SIZE
|
| 85 |
+
except ImportError:
|
| 86 |
+
ARBITRARY_SIZE = ParameterizedSize(min=1, step=1)
|
| 87 |
+
|
| 88 |
+
from bioimageio.spec.common import HttpUrl
|
| 89 |
+
from bioimageio.spec import save_bioimageio_package
|
| 90 |
+
from bioimageio.core import test_model
|
| 91 |
+
|
| 92 |
+
# Default channel selection passed to cellpose's convert_image.
# Presumably [cyto_channel, nuclear_channel] per the cellpose channel
# convention — confirm against convert_image's documentation.
DEFAULT_CHANNELS = [2, 1]
# Keyword arguments forwarded as **kwargs to cellpose.transforms.normalize_img
# in download_and_normalize_image below.
DEFAULT_NORMALIZE_PARAMS = {
    "axis": -1,
    "lowhigh": None,
    "percentile": None,
    "normalize": True,
    "norm3D": False,
    "sharpen_radius": 0,
    "smooth_radius": 0,
    "tile_norm_blocksize": 0,
    "tile_norm_smooth3D": 1,
    "invert": False,
}
# Sample image downloaded to build the bioimage.io test input/output tensors.
IMAGE_URL = "http://www.cellpose.org/static/data/rgb_3D.tif"
|
| 106 |
+
|
| 107 |
+
|
| 108 |
+
def download_and_normalize_image(path_dir_temp, channels=DEFAULT_CHANNELS):
    """
    Download and normalize image.

    Fetches the sample stack at IMAGE_URL (skipping the download if the file
    already exists), converts it to the cellpose channel layout, normalizes it,
    moves channels before the spatial axes, and pads it for the network.

    Args:
        path_dir_temp (Path): Directory where the downloaded image is cached.
        channels (list, optional): Channel selection passed to convert_image.
            NOTE(review): mutable default argument — safe only because it is
            never mutated here.

    Returns:
        np.ndarray: Normalized, transposed, padded float32 image (z, c, y, x).
    """
    filename = os.path.basename(urlparse(IMAGE_URL).path)
    path_image = path_dir_temp / filename
    if not path_image.exists():
        sys.stderr.write(f'Downloading: "{IMAGE_URL}" to {path_image}\n')
        download_url_to_file(IMAGE_URL, path_image)
    img = imread(path_image).astype(np.float32)
    img = convert_image(img, channels, channel_axis=1, z_axis=0, do_3D=False, nchan=2)
    img = normalize_img(img, **DEFAULT_NORMALIZE_PARAMS)
    # Move channels ahead of the spatial axes: (z, y, x, c) -> (z, c, y, x).
    img = np.transpose(img, (0, 3, 1, 2))
    img, _, _ = pad_image_ND(img)
    return img
|
| 123 |
+
|
| 124 |
+
|
| 125 |
+
def load_bioimageio_cpnet_model(path_model_weight, nchan=2):
    """Load a CPnetBioImageIO network from a Cellpose state-dict checkpoint.

    Args:
        path_model_weight (str or Path): Path to the saved PyTorch state dict.
        nchan (int, optional): Number of input channels. NOTE(review): currently
            unused — the network is constructed only with nout; confirm whether
            nchan should be forwarded to CPnetBioImageIO.

    Returns:
        tuple: (cpnet_biio, cpnet_kwargs) — the network in eval mode and the
        kwargs it was constructed with (needed later for the model description).
    """
    cpnet_kwargs = {
        "nout": 3,
    }
    cpnet_biio = CPnetBioImageIO(**cpnet_kwargs)
    # Always load onto CPU; weights_only=True avoids unpickling arbitrary objects.
    state_dict_cuda = torch.load(path_model_weight, map_location=torch.device("cpu"), weights_only=True)
    cpnet_biio.load_state_dict(state_dict_cuda)
    cpnet_biio.eval() # crucial for the prediction results
    return cpnet_biio, cpnet_kwargs
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def descr_gen_input(path_test_input, nchan=2):
    """Build the bioimage.io description of the model's input tensor "raw".

    Axes are (z, c, y, x): z may be any size; y and x are constrained to
    multiples of 16 with a minimum of 16; the channel axis carries ``nchan``
    channels named c1..cN. Data type is float32.

    Args:
        path_test_input (str or Path): Path to the saved test input tensor file.
        nchan (int, optional): Number of input channels. Defaults to 2.

    Returns:
        InputTensorDescr: Input-tensor description for the model spec.
    """
    input_axes = [
        SpaceInputAxis(id=AxisId("z"), size=ARBITRARY_SIZE),
        ChannelAxis(channel_names=[Identifier(f"c{i+1}") for i in range(nchan)]),
        SpaceInputAxis(id=AxisId("y"), size=ParameterizedSize(min=16, step=16)),
        SpaceInputAxis(id=AxisId("x"), size=ParameterizedSize(min=16, step=16)),
    ]
    data_descr = IntervalOrRatioDataDescr(type="float32")
    path_test_input = Path(path_test_input)
    descr_input = InputTensorDescr(
        id=TensorId("raw"),
        axes=input_axes,
        test_tensor=FileDescr(source=path_test_input),
        data=data_descr,
    )
    return descr_input
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def descr_gen_output_flow(path_test_output):
    """Create the BioImage.IO output descriptor for the 3-channel ``flow`` tensor.

    Every spatial axis mirrors the size of the corresponding axis of the
    ``raw`` input tensor via a SizeReference.
    """

    def _same_as_raw(axis_name):
        # Tie this output axis to the matching axis of the input tensor.
        return SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId(axis_name))

    flow_axes = [
        SpaceOutputAxis(id=AxisId("z"), size=_same_as_raw("z")),
        ChannelAxis(channel_names=[Identifier("flow1"), Identifier("flow2"), Identifier("flow3")]),
        SpaceOutputAxis(id=AxisId("y"), size=_same_as_raw("y")),
        SpaceOutputAxis(id=AxisId("x"), size=_same_as_raw("x")),
    ]
    return OutputTensorDescr(
        id=TensorId("flow"),
        axes=flow_axes,
        test_tensor=FileDescr(source=Path(path_test_output)),
    )
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def descr_gen_output_downsampled(path_dir_temp, nbase=None):
    """Create output descriptors for the per-level downsampled feature tensors.

    One descriptor per entry of ``nbase`` (the channel count of each level).
    Level ``k`` is described with spatial scale ``2**k`` relative to the raw
    input, and its test tensor is expected at
    ``path_dir_temp / "test_downsampled_{k}.npy"``.
    """
    if nbase is None:
        nbase = [32, 64, 128, 256]

    descriptors = []
    for level, n_features in enumerate(nbase):
        level_axes = [
            SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
            ChannelAxis(channel_names=[Identifier(f"feature{i+1}") for i in range(n_features)]),
            SpaceOutputAxis(
                id=AxisId("y"),
                size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("y")),
                scale=2**level,
            ),
            SpaceOutputAxis(
                id=AxisId("x"),
                size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("x")),
                scale=2**level,
            ),
        ]
        descriptors.append(
            OutputTensorDescr(
                id=TensorId(f"downsampled_{level}"),
                axes=level_axes,
                test_tensor=FileDescr(source=Path(path_dir_temp / f"test_downsampled_{level}.npy")),
            )
        )
    return descriptors
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def descr_gen_output_style(path_test_style, nchannel=256):
    """Create the output descriptor for the ``style`` vector.

    The style tensor has a z axis matching the raw input plus ``nchannel``
    feature channels (no y/x axes).
    """
    style_axes = [
        SpaceOutputAxis(id=AxisId("z"), size=SizeReference(tensor_id=TensorId("raw"), axis_id=AxisId("z"))),
        ChannelAxis(channel_names=[Identifier(f"feature{c + 1}") for c in range(nchannel)]),
    ]
    return OutputTensorDescr(
        id=TensorId("style"),
        axes=style_axes,
        test_tensor=FileDescr(source=Path(path_test_style)),
    )
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
def descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper=None):
    """Describe the Python architecture backing the state-dict weights.

    ``path_cpnet_wrapper`` defaults to the ``resnet_torch.py`` module that
    sits next to this script and exposes the ``CPnetBioImageIO`` callable.
    """
    wrapper = Path(__file__).parent / "resnet_torch.py" if path_cpnet_wrapper is None else path_cpnet_wrapper
    return ArchitectureFromFileDescr(
        callable=Identifier("CPnetBioImageIO"),
        source=Path(wrapper),
        kwargs=cpnet_kwargs,
    )
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
def descr_gen_documentation(path_doc, markdown_text):
    """Write the model documentation ``markdown_text`` to ``path_doc``.

    Writes with an explicit UTF-8 encoding: the previous implementation used
    the locale-dependent default, which can fail (or mojibake) on platforms
    whose default codec cannot represent non-ASCII characters such as author
    names or symbols in the markdown.
    """
    with open(path_doc, "w", encoding="utf-8") as f:
        f.write(markdown_text)
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def package_to_bioimageio(
    path_pretrained_model,
    path_save_trace,
    path_readme,
    list_path_cover_images,
    descr_input,
    descr_output,
    descr_output_downsampled_tensors,
    descr_output_style_tensor,
    pytorch_version,
    pytorch_architecture,
    model_id,
    model_icon,
    model_version,
    model_name,
    model_documentation,
    model_authors,
    model_cite,
    model_tags,
    model_license,
    model_repo,
):
    """Assemble a BioImage.IO ``ModelDescr`` from the prepared artifacts.

    ``model_authors``/``model_cite`` are lists of plain dicts (parsed from the
    CLI JSON) and are converted to spec objects here. The descriptor carries
    both the state-dict weights and the traced TorchScript weights.
    """
    # Convert plain-dict metadata into spec objects.
    authors = [
        Author(
            name=author["name"],
            affiliation=author["affiliation"],
            github_user=author["github_user"],
            orcid=OrcidId(author["orcid"]),
        )
        for author in model_authors
    ]
    citations = [CiteEntry(text=entry["text"], doi=Doi(entry["doi"]), url=entry["url"]) for entry in model_cite]

    weights = WeightsDescr(
        pytorch_state_dict=PytorchStateDictWeightsDescr(
            source=Path(path_pretrained_model),
            architecture=pytorch_architecture,
            pytorch_version=pytorch_version,
        ),
        torchscript=TorchscriptWeightsDescr(
            source=Path(path_save_trace),
            pytorch_version=pytorch_version,
            parent="pytorch_state_dict",  # these weights were converted from the pytorch_state_dict weights.
        ),
    )

    return ModelDescr(
        id=ModelId(model_id) if model_id is not None else None,
        id_emoji=model_icon,
        version=Version(model_version),
        name=model_name,
        description=model_documentation,
        authors=authors,
        cite=citations,
        covers=[Path(img) for img in list_path_cover_images],
        license=LicenseId(model_license),
        tags=model_tags,
        documentation=Path(path_readme),
        git_repo=HttpUrl(model_repo),
        inputs=[descr_input],
        outputs=[descr_output, descr_output_style_tensor, *descr_output_downsampled_tensors],
        weights=weights,
    )
|
| 296 |
+
|
| 297 |
+
|
| 298 |
+
def parse_args():
    """Parse the packaging CLI arguments and return the argparse namespace."""
    # fmt: off
    cli = argparse.ArgumentParser(description="BioImage.IO model packaging for Cellpose")
    cli.add_argument("--channels", nargs=2, default=[2, 1], type=int, help="Cyto-only = [2, 0], Cyto + Nuclei = [2, 1], Nuclei-only = [1, 0]")
    cli.add_argument("--path_pretrained_model", required=True, type=str, help="Path to pretrained model file, e.g., cellpose_residual_on_style_on_concatenation_off_1135_rest_2023_05_04_23_41_31.252995")
    cli.add_argument("--path_readme", required=True, type=str, help="Path to README file")
    cli.add_argument("--list_path_cover_images", nargs='+', required=True, type=str, help="List of paths to cover images")
    cli.add_argument("--model_id", type=str, help="Model ID, provide if already exists", default=None)
    cli.add_argument("--model_icon", type=str, help="Model icon, provide if already exists", default=None)
    cli.add_argument("--model_version", required=True, type=str, help="Model version, new model should be 0.1.0")
    cli.add_argument("--model_name", required=True, type=str, help="Model name, e.g., My Cool Cellpose")
    cli.add_argument("--model_documentation", required=True, type=str, help="Model documentation, e.g., A cool Cellpose model trained for my cool dataset.")
    cli.add_argument("--model_authors", required=True, type=str, help="Model authors in JSON format, e.g., '[{\"name\": \"Qin Yu\", \"affiliation\": \"EMBL\", \"github_user\": \"qin-yu\", \"orcid\": \"0000-0002-4652-0795\"}]'")
    cli.add_argument("--model_cite", required=True, type=str, help="Model citation in JSON format, e.g., '[{\"text\": \"For more details of the model itself, see the manuscript\", \"doi\": \"10.1242/dev.202800\", \"url\": null}]'")
    cli.add_argument("--model_tags", nargs='+', required=True, type=str, help="Model tags, e.g., cellpose 3d 2d")
    cli.add_argument("--model_license", required=True, type=str, help="Model license, e.g., MIT")
    cli.add_argument("--model_repo", required=True, type=str, help="Model repository URL")
    # fmt: on
    return cli.parse_args()
|
| 317 |
+
|
| 318 |
+
|
| 319 |
+
def main():
    """CLI entry point: run, trace, describe, test, and package a Cellpose model.

    The order of steps matters: every test tensor must be written to disk
    before the descriptors that reference it are built, and the traced
    TorchScript file must exist before packaging.
    """
    args = parse_args()

    # Parse user-provided paths and arguments (authors/cite arrive as JSON strings)
    channels = args.channels
    model_cite = json.loads(args.model_cite)
    model_authors = json.loads(args.model_authors)

    path_readme = Path(args.path_readme)
    path_pretrained_model = Path(args.path_pretrained_model)
    list_path_cover_images = [Path(path_image) for path_image in args.list_path_cover_images]

    # Auto-generated paths: one working directory per pretrained-model name
    path_cpnet_wrapper = Path(__file__).resolve().parent / "resnet_torch.py"
    path_dir_temp = Path(__file__).resolve().parent.parent / "models" / path_pretrained_model.stem
    path_dir_temp.mkdir(parents=True, exist_ok=True)

    path_save_trace = path_dir_temp / "cp_traced.pt"
    path_test_input = path_dir_temp / "test_input.npy"
    path_test_output = path_dir_temp / "test_output.npy"
    path_test_style = path_dir_temp / "test_style.npy"
    path_bioimageio_package = path_dir_temp / "cellpose_model.zip"

    # Download test input image, save it as the spec test tensor, and keep a torch copy
    img_np = download_and_normalize_image(path_dir_temp, channels=channels)
    np.save(path_test_input, img_np)
    img = torch.tensor(img_np).float()

    # Load model (eval mode, CPU weights)
    cpnet_biio, cpnet_kwargs = load_bioimageio_cpnet_model(path_pretrained_model)

    # Test model and save output: flow, style, then one file per downsampled level
    tuple_output_tensor = cpnet_biio(img)
    np.save(path_test_output, tuple_output_tensor[0].detach().numpy())
    np.save(path_test_style, tuple_output_tensor[1].detach().numpy())
    for i, t in enumerate(tuple_output_tensor[2:]):
        np.save(path_dir_temp / f"test_downsampled_{i}.npy", t.detach().numpy())

    # Save traced model (TorchScript weights for the package)
    model_traced = torch.jit.trace(cpnet_biio, img)
    model_traced.save(path_save_trace)

    # Generate model description parts from the tensors written above
    # NOTE(review): nbase[1:]/nbase[-1] assume cpnet_biio.nbase lists channel
    # counts with the input channels first — confirm against the architecture.
    descr_input = descr_gen_input(path_test_input)
    descr_output = descr_gen_output_flow(path_test_output)
    descr_output_downsampled_tensors = descr_gen_output_downsampled(path_dir_temp, nbase=cpnet_biio.nbase[1:])
    descr_output_style_tensor = descr_gen_output_style(path_test_style, cpnet_biio.nbase[-1])
    pytorch_version = Version(torch.__version__)
    pytorch_architecture = descr_gen_arch(cpnet_kwargs, path_cpnet_wrapper)

    # Package model metadata + weights into a single ModelDescr
    my_model_descr = package_to_bioimageio(
        path_pretrained_model,
        path_save_trace,
        path_readme,
        list_path_cover_images,
        descr_input,
        descr_output,
        descr_output_downsampled_tensors,
        descr_output_style_tensor,
        pytorch_version,
        pytorch_architecture,
        args.model_id,
        args.model_icon,
        args.model_version,
        args.model_name,
        args.model_documentation,
        model_authors,
        model_cite,
        args.model_tags,
        args.model_license,
        args.model_repo,
    )

    # Test model with both weight formats before shipping
    summary = test_model(my_model_descr, weight_format="pytorch_state_dict")
    summary.display()
    summary = test_model(my_model_descr, weight_format="torchscript")
    summary.display()

    # Save BioImage.IO package (.zip) and report where it landed
    package_path = save_bioimageio_package(my_model_descr, output_path=Path(path_bioimageio_package))
    print("package path:", package_path)
|
| 402 |
+
|
| 403 |
+
|
| 404 |
+
# Script entry point: only run the packaging pipeline when executed directly.
if __name__ == "__main__":
    main()
|