from functools import partial

import os
import sys

import torch
import torch.nn as nn
import numpy as np

import timm
from timm.models.registry import register_model
import huggingface_hub
from transformers import BertModel, BertConfig

from .pos_embed import get_2d_sincos_pos_embed
from .torchscale.model.LongNet import make_longnet_from_name


class Reducer(nn.Module):
    """Instruct Embedding: project instruction embeddings to a lower-dimensional space."""

    def __init__(
        self,
        in_chans=1536,
        embed_dim=768,
        norm_layer=None,
        bias=True,
    ):
        super().__init__()
        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        # x: [B, L, in_chans] -> [B, L, embed_dim]
        B, L, D = x.shape
        x = self.proj(x)
        x = self.norm(x)
        return x
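

# Example (a sketch, illustrative batch/sequence sizes): projecting 1536-dim instruction
# embeddings down to 768 dims with the defaults above.
#   reducer = Reducer(in_chans=1536, embed_dim=768)
#   out = reducer(torch.randn(2, 32, 1536))   # -> [2, 32, 768]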


class CrossAttention(nn.Module):
    """Thin wrapper around nn.MultiheadAttention for query-to-context cross-attention."""

    def __init__(self, embed_dim, num_heads):
        super(CrossAttention, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, query, key, value, key_padding_mask=None):
        # query: [B, Q, D]; key/value: [B, L, D]; key_padding_mask: [B, L] with True = padded
        output, attn_weights = self.attn(query, key, value, key_padding_mask=key_padding_mask)
        return output, attn_weights
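

# Example (a sketch, illustrative shapes): 4 query tokens attending over 196 tile tokens.
#   xattn = CrossAttention(embed_dim=512, num_heads=8)
#   q, kv = torch.randn(1, 4, 512), torch.randn(1, 196, 512)
#   out, attn = xattn(q, kv, kv)   # out: [1, 4, 512], attn: [1, 4, 196]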


class LongNetViT(nn.Module):
    """
    LongNet-based Vision Transformer backbone with vision-text interaction for downstream tasks.

    Arguments:
    ----------
    in_chans: int
        The number of input channels; should match the LLM encoding dimension (default 4096).
    embed_dim: int
        The embedding dimension of the LongNet model.
    depth: int
        The number of LongNet layers in each LongNet encoder.
    slide_ngrids: int
        The number of grids per side used to discretize the slide.
    tile_size: int
        The tile size in pixels. Default is 256.
    max_wsi_size: int
        The maximum size of the WSI in pixels.
    norm_layer: nn.LayerNorm
        The normalization layer used in the model.
    dropout: float
        The dropout rate used in the model.
    drop_path_rate: float
        The drop path rate used in the model.
    num_layers: int
        The number of stacked "LongNet encoder + self-attention + cross-attention" blocks.
    num_heads: int
        The number of attention heads used in the self- and cross-attention layers.
    """

    def __init__(self,
                 in_chans=4096,
                 embed_dim=512,
                 depth=12,
                 slide_ngrids=1000,
                 tile_size=256,
                 max_wsi_size=262144,
                 norm_layer=nn.LayerNorm,
                 dropout=0.25,
                 drop_path_rate=0.1,
                 num_layers=2,
                 num_heads=8,
                 **kwargs):
        super().__init__()

        print("#### Vision-Text Interaction based Adaptors (LongNet) ####")

        self.slide_ngrids = slide_ngrids
        num_patches = slide_ngrids**2

        # Fixed 2D sin-cos positional embedding over the slide grid (filled in initialize_vit_weights).
        self.register_buffer('pos_embed', torch.zeros(1, num_patches, embed_dim), persistent=True)
        self.num_layers = num_layers

        self.encoder_name = "LongNet_{}_layers_{}_dim".format(depth, embed_dim)
        if kwargs.get("mlp_ratio", 4.0) != 4.0:
            self.encoder_name += "_mlp{}".format(kwargs.get("mlp_ratio"))

        # BERT-style attention config (defined here but not used further in this constructor).
        config_self = BertConfig(
            hidden_size=embed_dim,
            num_attention_heads=num_heads,
            num_hidden_layers=1
        )

        # Segment lengths and dilation ratios for LongNet's dilated attention.
        segment_length, dilated_ratio = self.get_optimal_segment_length(max_wsi_size, tile_size)

        # One LongNet encoder per stacked block, applied to the tile (context) tokens.
        self.encoder_wsi = nn.ModuleList([
            make_longnet_from_name(self.encoder_name, drop_path_rate=drop_path_rate, dropout=dropout,
                                   segment_length=segment_length, dilated_ratio=dilated_ratio)
            for _ in range(num_layers)])

        # Self-attention over the [query; instruction] tokens and cross-attention from queries to tiles.
        self.self_attention = nn.ModuleList([nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) for _ in range(num_layers)])
        self.cross_attention = nn.ModuleList([CrossAttention(embed_dim, num_heads) for _ in range(num_layers)])

        print("self_attention params:", sum(p.numel() for p in self.self_attention.parameters() if p.requires_grad))
        print("cross_attention params:", sum(p.numel() for p in self.cross_attention.parameters() if p.requires_grad))

        self.norm = nn.ModuleList([norm_layer(embed_dim) for _ in range(num_layers)])

        self.initialize_vit_weights()

    def initialize_vit_weights(self):
        # Fill pos_embed with a fixed 2D sin-cos positional embedding over the slide grid.
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.slide_ngrids, cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # Initialize nn.Linear and nn.LayerNorm layers.
        self.apply(self._init_weights)

    def get_optimal_segment_length(self, max_wsi_size: int = 262144, tile_size: int = 256) -> tuple:
        '''
        Get the optimal segment lengths and dilation ratios based on the maximum WSI size and tile size.

        Arguments:
        ----------
        max_wsi_size: int
            The maximum size of the WSI in pixels.
        tile_size: int
            The tile size in pixels.

        Returns:
        --------
        segment_length, dilated_ratio: str
            Stringified lists of segment lengths and dilation ratios, in the format
            expected by make_longnet_from_name.
        '''
        max_seq_len = (max_wsi_size // tile_size) ** 2

        # Five segment lengths spaced evenly in log2 between 1024 and the maximum sequence length,
        # with dilation ratios 1, 2, 4, ... (one per segment length).
        segment_length = np.linspace(np.log2(1024), int(np.log2(max_seq_len)), 5)
        segment_length = np.power(2, segment_length).astype(int)
        dilated_ratio = str([2**i for i in range(len(segment_length))])

        segment_length = str(list(segment_length))
        return segment_length, dilated_ratio
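
    # Worked example (approximate), assuming the defaults above: max_wsi_size=262144 and
    # tile_size=256 give max_seq_len = (262144 // 256) ** 2 = 1024 ** 2 = 1,048,576 tiles,
    # so the exponents are np.linspace(10, 20, 5) = [10, 12.5, 15, 17.5, 20] and the
    # resulting segment lengths are roughly [1024, 5792, 32768, 185363, 1048576] with
    # dilation ratios [1, 2, 4, 8, 16].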

    def _init_weights(self, m):
        # Xavier-uniform init for linear weights, zero biases; LayerNorm to identity.
        if isinstance(m, nn.Linear):
            torch.nn.init.xavier_uniform_(m.weight)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def coords_to_pos(self, coords, patch_size=256.0):
        """
        Convert tile pixel coordinates to flat positional indices on the slide grid.

        Arguments:
        ----------
        coords: torch.Tensor
            The coordinates of the patches, of shape [N, L, 2]
        patch_size: float
            The tile size in pixels used to discretize the coordinates.

        Returns:
        --------
        pos: torch.Tensor
            The positional indices of the patches, of shape [N, L]
        """
        coords_ = torch.floor(coords / patch_size)
        pos = coords_[..., 0] * self.slide_ngrids + coords_[..., 1]
        return pos.long()
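
    # Worked example: with patch_size=256 and slide_ngrids=1000, a tile whose top-left
    # pixel coordinate is (512, 768) falls in grid cell (2, 3) and gets positional
    # index 2 * 1000 + 3 = 2003.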

    def forward(self, querys, contexts, instructs, coords, patch_size=256.0, self_attention_mask=None, key_padding_mask=None, slide_id=None, level=None):
        """
        The forward pass of the model.

        Arguments:
        ----------
        querys: torch.Tensor
            The learnable query tokens, of shape [N, Q, D]
        contexts: torch.Tensor
            The input tile embeddings, of shape [N, L, D]
        instructs: torch.Tensor
            The instruction (text) token embeddings, of shape [N, I, D]
        coords: torch.Tensor
            The coordinates of the patches, of shape [N, L, 2]
        self_attention_mask: torch.Tensor
            Validity mask for the instruction tokens, of shape [N, I] (1 = valid)
        key_padding_mask: torch.Tensor
            Padding mask for the tile tokens, of shape [N, L], passed to both the
            LongNet encoder and the cross-attention.

        slide_id and level are accepted for interface compatibility and are not used here.
        """
        query_length = querys.shape[1]

        if self_attention_mask is not None:
            # Queries are always valid: prepend an all-ones block so the mask covers [querys; instructs].
            query_mask_extension = torch.ones((querys.shape[0], querys.shape[1]), dtype=self_attention_mask.dtype, device=self_attention_mask.device)
            self_attention_mask = torch.cat((query_mask_extension, self_attention_mask), dim=-1)
        else:
            # No instruction mask supplied: treat every query and instruction token as valid.
            self_attention_mask = torch.ones((querys.shape[0], query_length + instructs.shape[1]), dtype=torch.bool, device=querys.device)

        # Scale the tile embeddings (hard-coded sqrt(512)) and add the fixed sin-cos
        # positional embedding looked up from the slide-grid indices.
        pos = self.coords_to_pos(coords=coords, patch_size=patch_size)
        contexts = contexts * np.sqrt(512) + self.pos_embed[:, pos, :].squeeze(0)

        for i in range(self.num_layers):
            # 1) LongNet encoder over the tile (context) tokens.
            contexts = self.encoder_wsi[i](src_tokens=None, token_embeddings=contexts, encoder_padding_mask=key_padding_mask)["encoder_out"]

            # 2) Self-attention over the concatenated [query; instruction] tokens.
            combined_querys = torch.cat((querys, instructs), dim=1)
            self_attn_output, query_text_weights = self.self_attention[i](
                combined_querys, combined_querys, combined_querys,
                key_padding_mask=~self_attention_mask.bool()
            )
            querys = self_attn_output[:, :query_length, :]

            querys = self.norm[i](querys)

            # 3) Cross-attention: the query tokens attend over the LongNet-encoded tile tokens.
            querys, query_pic_weights = self.cross_attention[i](query=querys, key=contexts, value=contexts, key_padding_mask=key_padding_mask)

        return querys.to(torch.bfloat16)


def create_model(pretrained: str, model_arch: str, in_chans: int, local_dir: str = os.path.join(os.path.expanduser("~"), ".cache/"), **kwargs):
    # Build the slide encoder through timm's registry and (optionally) load pretrained weights,
    # either from a local checkpoint path or from a HuggingFace Hub repo ("hf_hub:<repo_id>").
    model = timm.create_model(model_arch, pretrained=False, in_chans=in_chans, **kwargs)

    if pretrained.startswith("hf_hub:"):
        hub_name = pretrained.split(":")[1]
        huggingface_hub.hf_hub_download(hub_name, filename="slide_encoder.pth", local_dir=local_dir, force_download=True)
        local_path = os.path.join(local_dir, "slide_encoder.pth")
    else:
        local_path = pretrained

    if os.path.exists(local_path):
        state_dict = torch.load(local_path, map_location="cpu")["model"]

        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if len(missing_keys) > 0:
            for k in missing_keys:
                print("Missing ", k)

        if len(unexpected_keys) > 0:
            for k in unexpected_keys:
                print("Unexpected ", k)

        print("\033[92m Successfully Loaded Pretrained GigaPath model from {} \033[00m".format(pretrained))
    else:
        print("\033[93m Pretrained weights not found at {}. Randomly initialized the model! \033[00m".format(local_path))

    return model
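

# Example usage (a sketch): the hub id and in_chans value below are illustrative
# placeholders, not values fixed by this file.
#   model = create_model(
#       pretrained="hf_hub:<org>/<repo>",            # or a local path to slide_encoder.pth
#       model_arch="gigapath_slide_enc2l512d",
#       in_chans=1536,
#   )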


@register_model
def gigapath_slide_enc2l512d(**kwargs):
    # Two-layer LongNet encoder, 512-dim, on the default 256px tiles.
    model = LongNetViT(embed_dim=512, depth=2, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
    return model


@register_model
def gigapath_slide_enc1l512d_level0(**kwargs):
    # Single-layer, 512-dim variant for 1024px tiles (level 0).
    model = LongNetViT(embed_dim=512, depth=1, mlp_ratio=4, tile_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
    return model


@register_model
def gigapath_slide_enc1l512d_level1(**kwargs):
    # Single-layer, 512-dim variant for 2048px tiles (level 1).
    model = LongNetViT(embed_dim=512, depth=1, mlp_ratio=4, tile_size=2048, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
    return model


@register_model
def gigapath_slide_enc1l512d_level2(**kwargs):
    # Single-layer, 512-dim variant for 4096px tiles (level 2).
    model = LongNetViT(embed_dim=512, depth=1, mlp_ratio=4, tile_size=4096, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
    return model
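

if __name__ == "__main__":
    # Minimal smoke-test sketch (not part of the original training/evaluation pipeline).
    # Shapes below are illustrative assumptions; running this requires the bundled
    # torchscale LongNet implementation, and depending on its attention backend a CUDA
    # device may be needed, hence the device guard.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    model = gigapath_slide_enc2l512d().to(device)
    B, L, Q, T, D = 1, 196, 4, 16, 512       # slides, tiles, query tokens, instruction tokens, embed dim
    querys = torch.randn(B, Q, D, dtype=torch.bfloat16, device=device)
    contexts = torch.randn(B, L, D, dtype=torch.bfloat16, device=device)
    instructs = torch.randn(B, T, D, dtype=torch.bfloat16, device=device)
    coords = torch.randint(0, 256 * 100, (B, L, 2), device=device).float()   # tile top-left pixel coords
    instruct_mask = torch.ones(B, T, dtype=torch.long, device=device)        # 1 = valid instruction token
    with torch.no_grad():
        out = model(querys, contexts, instructs, coords, self_attention_mask=instruct_mask)
    print(out.shape)  # expected: torch.Size([1, 4, 512])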