# Scraped from a Hugging Face Space file view (file size: 7,341 bytes, commit aff3c6f).
"""
Copyright © 2025 Howard Hughes Medical Institute, Authored by Carsen Stringer and Marius Pachitariu.
"""
import torch
from segment_anything import sam_model_registry
torch.backends.cuda.matmul.allow_tf32 = True
from torch import nn
import torch.nn.functional as F
class Transformer(nn.Module):
    """Cellpose-SAM network: a SAM ViT image encoder with a per-pixel readout.

    The SAM image encoder is adapted to use ``ps`` x ``ps`` patches, global
    attention in every block, and position embeddings subsampled to the
    ``bsize // ps`` token grid. A 1x1 convolution maps the 256-channel neck
    output to ``nout * ps**2`` channels per token, and a fixed transposed
    convolution (``W2``) scatters those back to ``nout`` channels at pixel
    resolution.

    Args:
        backbone (str): key into ``sam_model_registry`` (e.g. ``"vit_l"``).
        ps (int): token (patch) size in pixels; must divide 16.
        nout (int): number of output channels. If changed from the default,
            weights will not load correctly from pretrained Cellpose-SAM.
        bsize (int): tile size in pixels used to size the position embeddings;
            ``bsize // ps`` must evenly divide the encoder's position grid.
        rdrop (float): per-layer drop probability during training rises
            linearly from 0 (first block) to ``rdrop`` (last block).
        checkpoint (str | None): SAM checkpoint passed to the registry builder;
            ``None`` (default) does not load SAM weights.
        dtype (torch.dtype): dtype the model is cast to by ``load_model``.

    Raises:
        ValueError: if ``ps`` does not divide 16, or if ``bsize // ps`` does not
            evenly divide the position-embedding grid.
    """

    def __init__(self, backbone="vit_l", ps=8, nout=3, bsize=256, rdrop=0.4,
                 checkpoint=None, dtype=torch.float32):
        super().__init__()
        # Reference: the stock SAM encoder modules adjusted below.
        #   print(self.encoder.patch_embed)
        #     PatchEmbed(
        #       (proj): Conv2d(3, 1024, kernel_size=(16, 16), stride=(16, 16))
        #     )
        #   print(self.encoder.neck)
        #     Sequential(
        #       (0): Conv2d(1024, 256, kernel_size=(1, 1), stride=(1, 1), bias=False)
        #       (1): LayerNorm2d()
        #       (2): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
        #       (3): LayerNorm2d()
        #     )

        # Validate before building the encoder: the 16x16 patch kernel is
        # subsampled with step 16 // ps, which only yields a ps x ps kernel
        # when ps divides 16.
        if ps <= 0 or 16 % ps != 0:
            raise ValueError(f"ps must be a positive divisor of 16, got {ps}")

        # instantiate the vit model, default to not loading SAM
        # checkpoint = sam_vit_l_0b3195.pth is standard pretrained SAM
        self.encoder = sam_model_registry[backbone](checkpoint).image_encoder
        w = self.encoder.patch_embed.proj.weight.detach()
        nchan = w.shape[0]

        # change token size to ps x ps by subsampling the pretrained kernel
        self.ps = ps
        self.encoder.patch_embed.proj = nn.Conv2d(3, nchan, stride=ps, kernel_size=ps)
        self.encoder.patch_embed.proj.weight.data = w[:, :, ::16 // ps, ::16 // ps]

        # adjust position embeddings for new bsize and new token size: keep
        # every ds-th embedding along both spatial axes so the grid has exactly
        # bsize // ps entries (only possible when that count divides the grid)
        grid = self.encoder.pos_embed.shape[1]
        ntok = bsize // ps
        if ntok == 0 or grid % ntok != 0:
            raise ValueError(
                f"bsize // ps = {ntok} tokens must evenly divide the {grid}-token "
                f"position-embedding grid (bsize={bsize}, ps={ps})")
        ds = grid // ntok
        self.encoder.pos_embed = nn.Parameter(self.encoder.pos_embed[:, ::ds, ::ds],
                                              requires_grad=True)

        # readout weights for nout output channels
        # if nout is changed, weights will not load correctly from pretrained Cellpose-SAM
        self.nout = nout
        self.out = nn.Conv2d(256, self.nout * ps**2, kernel_size=1)
        # W2 reshapes token space to pixel space, not trainable. As a
        # conv_transpose2d kernel with stride ps it routes token channel
        # c*ps*ps + a*ps + b to pixel (a, b) of output channel c in that
        # token's ps x ps patch.
        self.W2 = nn.Parameter(
            torch.eye(self.nout * ps**2).reshape(self.nout * ps**2, self.nout, ps, ps),
            requires_grad=False)
        # maximum fraction of layers to drop at random during training
        self.rdrop = rdrop
        # average diameter of ROIs from training images from fine-tuning
        self.diam_labels = nn.Parameter(torch.tensor([30.]), requires_grad=False)
        # average diameter of ROIs during main training
        self.diam_mean = nn.Parameter(torch.tensor([30.]), requires_grad=False)
        # set attention to global in every layer
        for blk in self.encoder.blocks:
            blk.window_size = 0
        self.dtype = dtype

    def forward(self, x, feat=None):
        """Run the encoder and the per-pixel readout.

        Args:
            x (torch.Tensor): image batch for the patch embedding (whose
                Conv2d takes 3 input channels).
            feat (torch.Tensor, optional): extra input embedded with the same
                patch embedding; it modulates the tokens as
                ``x + 0.5 * x * feat``.

        Returns:
            tuple: ``(output, style)``. ``output`` has ``nout`` channels at pixel
            resolution. ``style`` is a random ``(batch, 256)`` tensor, kept
            only for backwards compatibility with callers expecting a style
            vector.
        """
        # same progression as SAM until readout
        x = self.encoder.patch_embed(x)
        if feat is not None:
            feat = self.encoder.patch_embed(feat)
            x = x + x * feat * 0.5
        if self.encoder.pos_embed is not None:
            x = x + self.encoder.pos_embed

        if self.training and self.rdrop > 0:
            # layer drop: block i is skipped per sample with a probability
            # rising linearly from 0 (first block) to self.rdrop (last block)
            nlay = len(self.encoder.blocks)
            drop = (torch.rand((len(x), nlay), device=x.device) <
                    torch.linspace(0, self.rdrop, nlay, device=x.device)).to(x.dtype)
            for i, blk in enumerate(self.encoder.blocks):
                # (batch,) -> (batch, 1, 1, 1) to broadcast over the 4-D tokens
                mask = drop[:, i].unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
                x = x * mask + blk(x) * (1 - mask)
        else:
            for blk in self.encoder.blocks:
                x = blk(x)

        # the neck's Conv2d expects channels in dim 1, so move the last dim there
        x = self.encoder.neck(x.permute(0, 3, 1, 2))

        # readout is changed here
        x1 = self.out(x)
        x1 = F.conv_transpose2d(x1, self.W2, stride=self.ps, padding=0)
        # maintain the second output of feature size 256 for backwards compatibility
        return x1, torch.randn((x.shape[0], 256), device=x.device)

    def load_model(self, PATH, device, strict=False):
        """Load a state dict saved with ``save_model``, or by a DataParallel wrapper.

        Keys carrying the ``module.`` prefix (added by
        DataParallel/DistributedDataParallel) are stripped before loading.
        If ``self.dtype`` is not float32, the model is then cast to it in place.

        Args:
            PATH (str): path to the saved state dict.
            device: ``map_location`` forwarded to ``torch.load``.
            strict (bool): forwarded to ``load_state_dict``.
        """
        state_dict = torch.load(PATH, map_location=device, weights_only=True)
        prefix = "module."
        # Check every key, not just the first: an empty dict must not crash,
        # and only keys that actually carry the prefix are rewritten.
        if any(k.startswith(prefix) for k in state_dict):
            state_dict = {
                (k[len(prefix):] if k.startswith(prefix) else k): v
                for k, v in state_dict.items()
            }
        self.load_state_dict(state_dict, strict=strict)
        if self.dtype != torch.float32:
            # Module.to casts parameters in place; no rebinding needed
            self.to(self.dtype)

    @property
    def device(self):
        """
        Get the device of the model.

        Returns:
            torch.device: The device of the model.
        """
        return next(self.parameters()).device

    def save_model(self, filename):
        """
        Save the model to a file.

        Args:
            filename (str): The path to the file where the model will be saved.
        """
        torch.save(self.state_dict(), filename)
class CPnetBioImageIO(Transformer):
    """
    A subclass of the CP-SAM model compatible with the BioImage.IO Spec.

    It adapts the model's forward output and weight loading so the model can
    use the weights uploaded to the BioImage.IO Model Zoo.
    """

    def forward(self, x):
        """
        Perform a forward pass and return the outputs as a flat tuple.

        ``Transformer.forward`` returns ``(output, style)``. If the parent ever
        returns a third element (a sequence of downsampled tensors), those
        tensors are appended after the style tensor.

        Args:
            x (torch.Tensor): Input tensor.

        Returns:
            tuple: The output tensor, the style tensor, then any downsampled
            tensors.
        """
        # Transformer.forward yields two values; the old fixed 3-way unpack
        # always raised ValueError.
        output_tensor, style_tensor, *extra = super().forward(x)
        downsampled_tensors = extra[0] if extra else ()
        return (output_tensor, style_tensor, *downsampled_tensors)

    def load_model(self, filename, device=None):
        """
        Load the model weights from a file.

        Args:
            filename (str): The path to the file where the model is saved.
            device (torch.device, optional): Device to map the weights onto.
                Defaults to None, which maps them to the CPU.
        """
        # Previously the CPU path called self.__init__(self.nout), which passed
        # nout as `backbone` and failed. The existing instance already has the
        # right architecture, so the weights are loaded into it directly.
        map_location = device if device is not None else torch.device("cpu")
        state_dict = torch.load(filename, map_location=map_location, weights_only=True)
        self.load_state_dict(state_dict)

    def load_state_dict(self, state_dict, strict=False):
        """
        Load the state dictionary into the model.

        This overrides ``nn.Module.load_state_dict`` to handle Cellpose's
        custom loading and stay compatible with BioImage.IO Core. If the
        readout (``out.weight``) in ``state_dict`` was trained for a different
        number of output channels than this model, the readout weights
        (``out.*`` and ``W2``) are skipped and only the remaining weights are
        loaded.

        Args:
            state_dict (Mapping[str, Any]): A state dictionary to load into the model.
            strict (bool): Forwarded to ``nn.Module.load_state_dict``. It is
                forced to False when the readout weights are skipped.

        Returns:
            The ``NamedTuple`` of missing/unexpected keys from
            ``nn.Module.load_state_dict``.
        """
        readout = state_dict.get("out.weight")
        # self.out has nout * ps**2 output channels (one per pixel of each token patch)
        if readout is not None and readout.shape[0] != self.nout * self.ps ** 2:
            # Shape-mismatched tensors raise even with strict=False, so they
            # must be dropped rather than left for the loader to ignore.
            state_dict = {
                name: param for name, param in state_dict.items()
                if not name.startswith("out.") and name != "W2"
            }
            strict = False
        return super().load_state_dict(dict(state_dict), strict=strict)
|