# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# MAE: https://github.com/facebookresearch/mae
# --------------------------------------------------------
#
# Portions Copyright Prov-GigaPath
# Original File: https://github.com/facebookresearch/mae

from functools import partial
import os
import sys

import torch
import torch.nn as nn
import numpy as np
import timm
from timm.models.registry import register_model
import huggingface_hub
from transformers import BertModel, BertConfig

from .pos_embed import get_2d_sincos_pos_embed
from .torchscale.model.LongNet import make_longnet_from_name


class Reducer(nn.Module):
    """Instruct Embedding"""

    def __init__(
        self,
        in_chans=1536,
        embed_dim=768,
        norm_layer=None,
        bias=True,
    ):
        super().__init__()
        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, L, D = x.shape
        x = self.proj(x)
        x = self.norm(x)
        return x


class CrossAttention(nn.Module):
    def __init__(self, embed_dim, num_heads):
        super(CrossAttention, self).__init__()
        self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)

    def forward(self, query, key, value, key_padding_mask=None):
        # query: (batch_size, query_len, embed_dim)
        # key:   (batch_size, key_len, embed_dim)
        # value: (batch_size, key_len, embed_dim)
        output, attn_weights = self.attn(query, key, value, key_padding_mask=key_padding_mask)
        return output, attn_weights


class LongNetViT(nn.Module):
    """
    Backbone of Vision Transformer for downstream tasks

    Arguments:
    ----------
    in_chans: int
        The number of input channels; should be the LLM encoding dimension 4096.
    embed_dim: int
        The embedding dimension of the LongNet model.
    depth: int
        The number of LongNet layers in the LongNet model.
    slide_ngrids: int
        The number of grids in the slide.
    tile_size: int
        The tile size. Default is 256px.
    max_wsi_size: int
        The maximum size of the WSI.
    norm_layer: nn.LayerNorm
        The normalization layer used in the model.
    global_pool: bool
        Whether to use global pooling or not.
    dropout: float
        The dropout rate used in the model.
    drop_path_rate: float
        The drop path rate used in the model.
    num_layers: int
        The number of stacked "encoder + cross-attention" blocks.
    """

    def __init__(self, in_chans=4096, embed_dim=512, depth=12, slide_ngrids=1000,
                 tile_size=256, max_wsi_size=262144, norm_layer=nn.LayerNorm,
                 dropout=0.25, drop_path_rate=0.1, num_layers=2, num_heads=8, **kwargs):
        super().__init__()

        # --------------------------------------------------------------------------
        print("#### Vision-Text Interaction based Adaptors (LongNet) ####")
        self.slide_ngrids = slide_ngrids
        num_patches = slide_ngrids**2
        self.register_buffer('pos_embed', torch.zeros(1, num_patches, embed_dim), persistent=True)  # fixed sin-cos embedding

        self.num_layers = num_layers
        self.encoder_name = "LongNet_{}_layers_{}_dim".format(depth, embed_dim)
        if kwargs.get("mlp_ratio", 4.0) != 4.0:
            self.encoder_name += "_mlp{}".format(kwargs.get("mlp_ratio"))

        config_self = BertConfig(
            hidden_size=embed_dim,
            num_attention_heads=num_heads,
            num_hidden_layers=1  # use only a single BERT layer (for the commented-out BertModel variant below)
        )

        # get the optimal segment lengths and dilation ratios for LongNet
        segment_length, dilated_ratio = self.get_optimal_segment_length(max_wsi_size, tile_size)
        self.encoder_wsi = nn.ModuleList([
            make_longnet_from_name(self.encoder_name,
                                   drop_path_rate=drop_path_rate,
                                   dropout=dropout,
                                   segment_length=segment_length,
                                   dilated_ratio=dilated_ratio)
            for _ in range(num_layers)
        ])
        # self.reduce = Reducer(in_chans, embed_dim)
        # self.self_attention = nn.ModuleList([BertModel(config_self) for _ in range(num_layers)])
        self.self_attention = nn.ModuleList([
            nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
            for _ in range(num_layers)
        ])
        self.cross_attention = nn.ModuleList([
            CrossAttention(embed_dim, num_heads)
            for _ in range(num_layers)
        ])
        print("self_attention params:", sum(p.numel() for p in self.self_attention.parameters() if p.requires_grad))
        print("cross_attention params:", sum(p.numel() for p in self.cross_attention.parameters() if p.requires_grad))
        # self.encoder2 = make_longnet_from_name(self.encoder_name, drop_path_rate=drop_path_rate, dropout=dropout, segment_length=segment_length)
        self.norm = nn.ModuleList([norm_layer(embed_dim) for _ in range(num_layers)])
        # --------------------------------------------------------------------------

        self.initialize_vit_weights()

    def initialize_vit_weights(self):
        # initialization
        # initialize (and freeze) pos_embed by sin-cos embedding
        pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.slide_ngrids, cls_token=False)
        self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))

        # initialize reduce like nn.Linear (instead of nn.Conv2d)
        # w = self.reduce.proj.weight.data
        # torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))

        # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
        # torch.nn.init.normal_(self.cls_token, std=0.02)

        # initialize nn.Linear and nn.LayerNorm
        self.apply(self._init_weights)

    def get_optimal_segment_length(self, max_wsi_size: int = 262144, tile_size: int = 256) -> tuple:
        '''
        Get the optimal segment lengths and dilation ratios based on the maximum image size and tile size.

        Arguments:
        ----------
        max_wsi_size: int
            The maximum size of the WSI.
        tile_size: int
            The tile size.
        '''
        max_seq_len = (max_wsi_size // tile_size) ** 2
        # calculate the segment lengths: five log-spaced values from 1024 up to the max sequence length
        segment_length = np.linspace(np.log2(1024), int(np.log2(max_seq_len)), 5)
        segment_length = np.power(2, segment_length).astype(int)
        dilated_ratio = str([2**i for i in range(len(segment_length))])
        # convert to the str format expected by make_longnet_from_name
        segment_length = str(list(segment_length))
        return segment_length, dilated_ratio

    def _init_weights(self, m):
        if isinstance(m, nn.Linear):
            # we use xavier_uniform following official JAX ViT:
            torch.nn.init.xavier_uniform_(m.weight)
            if isinstance(m, nn.Linear) and m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def coords_to_pos(self, coords, patch_size=256.0):
        """
        Convert tile coordinates to positional indices into the fixed pos_embed grid.

        Arguments:
        ----------
        coords: torch.Tensor
            The coordinates of the patches, of shape [N, L, 2]

        Returns:
        --------
        torch.Tensor
            The positional indices of the patches, of shape [N, L]
        """
        coords_ = torch.floor(coords / patch_size)
        pos = coords_[..., 0] * self.slide_ngrids + coords_[..., 1]
        return pos.long()  # + 1  # add 1 for the cls token

    def forward(self, querys, contexts, instructs, coords, patch_size=256.0,
                self_attention_mask=None, key_padding_mask=None, slide_id=None, level=None):
        """
        The forward pass of the model

        Arguments:
        ----------
        querys: torch.Tensor
            The query embeddings, of shape [N, Q, D]
        contexts: torch.Tensor
            The input tile embeddings, of shape [N, L, D]
        instructs: torch.Tensor
            The instruction token embeddings, of shape [N, T, D]
        coords: torch.Tensor
            The coordinates of the patches, of shape [N, L, 2]
        """
        query_length = querys.shape[1]

        # mask for self-attention: the query tokens are always valid,
        # so prepend an all-ones extension to the instruct mask
        if self_attention_mask is not None:
            query_mask_extension = torch.ones((querys.shape[0], querys.shape[1]),
                                              dtype=self_attention_mask.dtype,
                                              device=self_attention_mask.device)
            self_attention_mask = torch.cat((query_mask_extension, self_attention_mask), dim=-1)

        # embed patches: get pos indices and add the fixed sin-cos positional embedding
        pos = self.coords_to_pos(coords=coords, patch_size=patch_size)  # [N, L]
        contexts = contexts * np.sqrt(512) + self.pos_embed[:, pos, :].squeeze(0)  # scale by sqrt(embed_dim)

        # embed instruct, 4096 -> 512
        # instructs = self.reduce(instructs)

        for i in range(self.num_layers):
            # LongNet encoder over the WSI tile tokens
            contexts = self.encoder_wsi[i](src_tokens=None, token_embeddings=contexts,
                                           encoder_padding_mask=key_padding_mask)["encoder_out"]

            # self-attention for querys and instructs interaction
            combined_querys = torch.cat((querys, instructs), dim=1)
            # self_attn_output = self.self_attention[i](
            #     inputs_embeds=combined_querys,
            #     attention_mask=self_attention_mask
            # ).last_hidden_state
            self_attn_output, query_text_weights = self.self_attention[i](
                combined_querys, combined_querys, combined_querys,
                key_padding_mask=~self_attention_mask.bool() if self_attention_mask is not None else None
            )
            querys = self_attn_output[:, :query_length, :]  # keep the query vectors only

            # norm querys
            querys = self.norm[i](querys)

            # cross-attention: key_padding_mask masks padded patch tokens
            querys, query_pic_weights = self.cross_attention[i](query=querys, key=contexts, value=contexts,
                                                                key_padding_mask=key_padding_mask)
            # if i == self.num_layers - 1:
            #     num = 0
            #     while os.path.exists(f"output/visualization_question_close/{slide_id}_level{level}_querypicweights_{num}.pt"):
            #         num += 1
            #     torch.save(query_pic_weights, f"output/visualization_question_close/{slide_id}_level{level}_querypicweights_{num}.pt")

        return querys.to(torch.bfloat16)

        # legacy global-pooling branch from the original GigaPath encoder (unreachable, kept for reference)
        # outcomes = []
        # print("x_list:", len(x_list))
        # for x in x_list:
        #     print("before pool:", x.shape)
        #     if self.global_pool:
        #         x = x[:, 1:, :].mean(dim=1)  # global average pooling
        #         outcome = self.norm(x)
        #         print("after pool:", x.shape)
        #     else:
        #         x = self.norm(x)
        #         outcome = x[:, 0]
        #     outcomes.append(outcome)
        # return outcomes
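
# ----------------------------------------------------------------------------
# Worked examples (illustrative only; the numbers below are derived by hand
# from the methods above and are not used anywhere at runtime):
#
# get_optimal_segment_length with the defaults max_wsi_size=262144 and
# tile_size=256 gives max_seq_len = (262144 // 256) ** 2 = 1024 ** 2 = 1,048,576
# tiles, so the derived LongNet config strings are approximately
#     segment_length = "[1024, 5792, 32768, 185363, 1048576]"
#     dilated_ratio  = "[1, 2, 4, 8, 16]"
# i.e. five log-spaced segment lengths from 1024 tokens up to the full
# sequence, each paired with a power-of-two dilation factor.
#
# coords_to_pos with patch_size=256 maps a tile at pixel coordinates
# (51200, 76800) to grid cell (200, 300) and hence to positional index
# 200 * slide_ngrids + 300 = 200300 (with the default slide_ngrids=1000),
# which selects the corresponding row of the fixed sin-cos pos_embed buffer.
# ----------------------------------------------------------------------------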
def create_model(pretrained: str, model_arch: str, in_chans: int,
                 local_dir: str = os.path.join(os.path.expanduser("~"), ".cache/"), **kwargs):
    model = timm.create_model(model_arch, pretrained=False, in_chans=in_chans, **kwargs)

    if pretrained.startswith("hf_hub:"):
        hub_name = pretrained.split(":")[1]
        huggingface_hub.hf_hub_download(hub_name, filename="slide_encoder.pth", local_dir=local_dir, force_download=True)
        local_path = os.path.join(local_dir, "slide_encoder.pth")
    else:
        local_path = pretrained

    if os.path.exists(local_path):
        state_dict = torch.load(local_path, map_location="cpu")["model"]
        missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
        if len(missing_keys) > 0:
            for k in missing_keys:
                print("Missing ", k)
        if len(unexpected_keys) > 0:
            for k in unexpected_keys:
                print("Unexpected ", k)
        print("\033[92m Successfully Loaded Pretrained GigaPath model from {} \033[00m".format(pretrained))
    else:
        print("\033[93m Pretrained weights not found at {}. Randomly initialized the model! \033[00m".format(local_path))

    return model


@register_model
def gigapath_slide_enc2l512d(**kwargs):
    model = LongNetViT(embed_dim=512, depth=2, mlp_ratio=4,
                       norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
    return model


@register_model
def gigapath_slide_enc1l512d_level0(**kwargs):
    model = LongNetViT(embed_dim=512, depth=1, mlp_ratio=4, tile_size=1024,
                       norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
    return model


@register_model
def gigapath_slide_enc1l512d_level1(**kwargs):
    model = LongNetViT(embed_dim=512, depth=1, mlp_ratio=4, tile_size=2048,
                       norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
    return model


@register_model
def gigapath_slide_enc1l512d_level2(**kwargs):
    model = LongNetViT(embed_dim=512, depth=1, mlp_ratio=4, tile_size=4096,
                       norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
    return model
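
# ----------------------------------------------------------------------------
# Minimal smoke-test sketch (an illustration, not part of the original
# pipeline). Shapes follow the docstrings above: contexts are tile embeddings
# [N, L, D] and coords are tile coordinates in pixels [N, L, 2]. The query and
# instruct lengths (16 and 32), the all-valid masks, and the smaller
# slide_ngrids=100 grid are made-up example values. Run it as a module so the
# relative imports resolve; it also assumes the torchscale LongNet dependencies
# are installed and that bfloat16 ops are supported on the device.
# ----------------------------------------------------------------------------
if __name__ == "__main__":
    # small grid keeps the fixed pos_embed buffer tiny for this sketch
    model = gigapath_slide_enc2l512d(slide_ngrids=100)
    N, L, D = 1, 196, 512
    querys = torch.randn(N, 16, D, dtype=torch.bfloat16)
    instructs = torch.randn(N, 32, D, dtype=torch.bfloat16)
    contexts = torch.randn(N, L, D, dtype=torch.bfloat16)
    coords = torch.randint(0, 25_000, (N, L, 2)).float()       # pixel coords inside the 100x100 grid
    self_attention_mask = torch.ones(N, 32, dtype=torch.long)  # 1 = real instruct token
    key_padding_mask = torch.zeros(N, L, dtype=torch.bool)     # False = real tile token
    out = model(querys, contexts, instructs, coords,
                self_attention_mask=self_attention_mask,
                key_padding_mask=key_padding_mask)
    print(out.shape)  # expected: torch.Size([1, 16, 512])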