# Llama-slideQA/gigapath/slide_encoder.py
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
# --------------------------------------------------------
# References:
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
# DeiT: https://github.com/facebookresearch/deit
# MAE: https://github.com/facebookresearch/mae
# --------------------------------------------------------
#
# Portions Copyright Prov-GigaPath
# Original File: https://github.com/facebookresearch/mae
from functools import partial
import os
import sys
import torch
import torch.nn as nn
import numpy as np
import timm
from timm.models.registry import register_model
import huggingface_hub
from transformers import BertModel, BertConfig
from .pos_embed import get_2d_sincos_pos_embed
from .torchscale.model.LongNet import make_longnet_from_name
class Reducer(nn.Module):
"""Instruct Embedding"""
def __init__(
self,
in_chans=1536,
embed_dim=768,
norm_layer=None,
bias=True,
):
super().__init__()
self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
def forward(self, x):
B, L, D = x.shape
x = self.proj(x)
x = self.norm(x)
return x
class CrossAttention(nn.Module):
def __init__(self, embed_dim, num_heads):
super(CrossAttention, self).__init__()
self.attn = nn.MultiheadAttention(embed_dim, num_heads, batch_first=True)
def forward(self, query, key, value, key_padding_mask=None):
# query: (batch_size, query_len, embed_dim)
# key: (batch_size, key_len, embed_dim)
# value: (batch_size, key_len, embed_dim)
output, attn_weights = self.attn(query, key, value, key_padding_mask=key_padding_mask)
return output, attn_weights
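
# A minimal usage sketch (not part of the original pipeline): hypothetical batch/query/key lengths
# chosen only to illustrate the tensor shapes CrossAttention expects and returns.
def _example_cross_attention_shapes():
    batch_size, query_len, key_len, embed_dim = 2, 4, 16, 512
    xattn = CrossAttention(embed_dim=embed_dim, num_heads=8)
    query = torch.randn(batch_size, query_len, embed_dim)
    key = value = torch.randn(batch_size, key_len, embed_dim)
    # True entries in key_padding_mask mark padded key positions that should be ignored
    key_padding_mask = torch.zeros(batch_size, key_len, dtype=torch.bool)
    output, attn_weights = xattn(query, key, value, key_padding_mask=key_padding_mask)
    assert output.shape == (batch_size, query_len, embed_dim)
    assert attn_weights.shape == (batch_size, query_len, key_len)
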
class LongNetViT(nn.Module):
"""
Backbone of Vision Transformer for downstream tasks
Arguments:
----------
in_chans: int
The number of input channels, should be the llm encoding dimension 4096.
embed_dim: int
The embedding dimension of the LongNet model.
depth: int
The number of LongNet layers in the LongNet model.
slide_ngrids: int
The number of grids in the slide.
tile_size: int
The tile size. Default is 256px.
max_wsi_size: int
The maximum size of the WSI.
norm_layer: nn.LayerNorm
The normalization layer used in the model.
global_pool: bool
Whether to use global pooling or not.
dropout: float
The dropout rate used in the model.
drop_path_rate: float
The drop path rate used in the model.
num_layers: int
The number of stacked "encoder and xatten"
"""
def __init__(self,
in_chans=4096,
embed_dim=512,
depth=12,
slide_ngrids=1000,
tile_size=256,
max_wsi_size=262144,
norm_layer=nn.LayerNorm,
dropout=0.25,
drop_path_rate=0.1,
num_layers = 2,
num_heads = 8,
**kwargs):
super().__init__()
# --------------------------------------------------------------------------
print("####Vision-Text Interaction based Adaptors (Longnet) ####")
self.slide_ngrids = slide_ngrids
num_patches = slide_ngrids**2
self.register_buffer('pos_embed', torch.zeros(1, num_patches, embed_dim), persistent=True) # fixed sin-cos embedding
self.num_layers = num_layers
self.encoder_name = "LongNet_{}_layers_{}_dim".format(depth, embed_dim)
if kwargs.get("mlp_ratio", 4.0) != 4.0:
self.encoder_name += "_mlp{}".format(kwargs.get("mlp_ratio"))
        config_self = BertConfig(
            hidden_size=embed_dim,
            num_attention_heads=num_heads,
            num_hidden_layers=1  # use only one BERT layer (only needed by the commented-out BertModel variant below)
        )
# get optimal segment length
        segment_length, dilated_ratio = self.get_optimal_segment_length(max_wsi_size, tile_size)
        self.encoder_wsi = nn.ModuleList([
            make_longnet_from_name(self.encoder_name, drop_path_rate=drop_path_rate, dropout=dropout,
                                   segment_length=segment_length, dilated_ratio=dilated_ratio)
            for _ in range(num_layers)])
# self.reduce = Reducer(in_chans, embed_dim)
# self.self_attention = nn.ModuleList([BertModel(config_self) for _ in range(num_layers)])#1001
        self.self_attention = nn.ModuleList([nn.MultiheadAttention(embed_dim, num_heads, batch_first=True) for _ in range(num_layers)])
self.cross_attention = nn.ModuleList([CrossAttention(embed_dim, num_heads) for _ in range(num_layers)])
print("self_attention:",sum(p.numel() for p in self.self_attention.parameters() if p.requires_grad))
print("CROSS_attentiion:",sum(p.numel() for p in self.cross_attention.parameters() if p.requires_grad))
# self.encoder2 = make_longnet_from_name(self.encoder_name, drop_path_rate=drop_path_rate, dropout=dropout, segment_length=segment_length)
self.norm = nn.ModuleList([norm_layer(embed_dim) for _ in range(num_layers)])
# --------------------------------------------------------------------------
self.initialize_vit_weights()
def initialize_vit_weights(self):
# initialization
# initialize (and freeze) pos_embed by sin-cos embedding
pos_embed = get_2d_sincos_pos_embed(self.pos_embed.shape[-1], self.slide_ngrids, cls_token=False)
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
# initialize reduce like nn.Linear (instead of nn.Conv2d)
# w = self.reduce.proj.weight.data
# torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
# torch.nn.init.normal_(self.cls_token, std=0.02)
# initialize nn.Linear and nn.LayerNorm
self.apply(self._init_weights)
    def get_optimal_segment_length(self, max_wsi_size: int = 262144, tile_size: int = 256) -> tuple:
        '''
        Get the optimal segment lengths and dilated ratios based on the maximum image size and tile size.

        Arguments:
        ----------
        max_wsi_size: int
            The maximum size of the WSI in pixels.
        tile_size: int
            The tile size in pixels.

        Returns a pair of strings: the list of segment lengths and the list of dilated ratios.
        '''
max_seq_len = (max_wsi_size // tile_size) ** 2
# calculate the segment length
segment_length = np.linspace(np.log2(1024), int(np.log2(max_seq_len)), 5)
segment_length = np.power(2, segment_length).astype(int)
dilated_ratio = str([2**i for i in range(len(segment_length))])
# convert to str format
segment_length = str(list(segment_length))
        return segment_length, dilated_ratio
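
    # Worked example (illustrative, not executed): with the defaults max_wsi_size=262144 and
    # tile_size=256, max_seq_len = (262144 // 256) ** 2 = 1048576, so the exponents from np.linspace
    # are [10, 12.5, 15, 17.5, 20]. The segment lengths are therefore [1024, 5792, 32768, 185363,
    # 1048576] (after truncation to int) and the dilated ratios are [1, 2, 4, 8, 16], both returned
    # as strings for make_longnet_from_name.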
def _init_weights(self, m):
if isinstance(m, nn.Linear):
# we use xavier_uniform following official JAX ViT:
torch.nn.init.xavier_uniform_(m.weight)
if isinstance(m, nn.Linear) and m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.LayerNorm):
nn.init.constant_(m.bias, 0)
nn.init.constant_(m.weight, 1.0)
def coords_to_pos(self, coords, patch_size=256.0):
"""
This function is used to convert the coordinates to the positional indices
Arguments:
----------
coords: torch.Tensor
The coordinates of the patches, of shape [N, L, 2]
output: torch.Tensor
The positional indices of the patches, of shape [N, L]
"""
coords_ = torch.floor(coords / patch_size)
pos = coords_[..., 0] * self.slide_ngrids + coords_[..., 1]
        return pos.long()  # no +1 offset is added since there is no cls token
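
    # Worked example (illustrative): with the default slide_ngrids=1000 and patch_size=256.0, a patch
    # at pixel coordinates (1024, 512) falls into grid cell (4, 2), so its positional index is
    # 4 * 1000 + 2 = 4002, i.e. the row of self.pos_embed that is added to that patch embedding.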
    def forward(self, querys, contexts, instructs, coords, patch_size=256.0, self_attention_mask=None, key_padding_mask=None, slide_id=None, level=None):
"""
The forward pass of the model
Arguments:
----------
contexts: torch.Tensor
The input tile embeddings, of shape [N, L, D]
coords: torch.Tensor
The coordinates of the patches, of shape [N, L, 2]
"""
query_length = querys.shape[1]
# mask for self attn
if self_attention_mask is not None:
            # the query tokens are always valid, so extend the mask with ones for them
query_mask_extension = torch.ones((querys.shape[0], querys.shape[1]), dtype=self_attention_mask.dtype, device=self_attention_mask.device)
self_attention_mask = torch.cat((query_mask_extension, self_attention_mask), dim=-1)
# embed patches
# get pos indices
pos = self.coords_to_pos(coords=coords, patch_size=patch_size) # [N, L]
        contexts = contexts * np.sqrt(512) + self.pos_embed[:, pos, :].squeeze(0)  # scale tile embeddings by sqrt(embed_dim) (hard-coded for embed_dim=512) and add fixed positional embeddings
# embed instruct, 4096 -> 512
# instructs = self.reduce(instructs)
for i in range(self.num_layers):
# longnet for wsi tokens
            contexts = self.encoder_wsi[i](src_tokens=None, token_embeddings=contexts, encoder_padding_mask=key_padding_mask)["encoder_out"]
# self-attention for querys and instructs interaction
combined_querys = torch.cat((querys, instructs), dim=1)
# self_attn_output = self.self_attention[i](
# inputs_embeds=combined_querys,
# attention_mask=self_attention_mask
# ).last_hidden_state
            self_attn_output, query_text_weights = self.self_attention[i](
                combined_querys, combined_querys, combined_querys,
                key_padding_mask=(~self_attention_mask.bool()) if self_attention_mask is not None else None
            )
querys = self_attn_output[:, :query_length, :] # keep query vector only
# norm querys
querys = self.norm[i](querys)
# Cross-Attention: key_padding_mask for padded patch tokens
querys, query_pic_weights = self.cross_attention[i](query=querys, key=contexts, value=contexts, key_padding_mask=key_padding_mask)
# if i ==self.num_layers-1:
# num = 0
# while os.path.exists(f"output/visualization_question_close/{slide_id}_level{level}_querypicweights_{num}.pt"):
# num+=1
# torch.save(query_pic_weights,f"output/visualization_question_close/{slide_id}_level{level}_querypicweights_{num}.pt")
return querys.to(torch.bfloat16)
# outcomes = []
# print("x_list:",len(x_list))
# for x in x_list:
# print("before pool:",x.shape)
# if self.global_pool:
# x = x[:, 1:, :].mean(dim=1) # global average pooling
# outcome = self.norm(x)
# print("after pool:",x.shape)
# else:
# x = self.norm(x)
# outcome = x[:, 0]
# outcomes.append(outcome)
# return outcomes
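
# A minimal forward-pass sketch (illustrative assumptions, not part of the original pipeline): the
# batch size, token counts, and dummy tensors are hypothetical, real slides contain far more tile
# tokens, and running it requires the torchscale LongNet dependency behind make_longnet_from_name.
# It only illustrates the input shapes LongNetViT.forward expects.
def _example_longnetvit_forward():
    model = LongNetViT(embed_dim=512, depth=2, num_layers=2, num_heads=8)
    n_query, n_instruct, n_tiles = 4, 8, 32            # hypothetical token counts
    querys = torch.randn(1, n_query, 512)
    instructs = torch.randn(1, n_instruct, 512)        # instruction embeddings already at embed_dim
    contexts = torch.randn(1, n_tiles, 512)            # tile embeddings already at embed_dim
    coords = torch.randint(0, 256 * 100, (1, n_tiles, 2)).float()      # pixel coordinates of each tile
    self_attention_mask = torch.ones(1, n_instruct, dtype=torch.long)  # 1 = valid instruction token
    key_padding_mask = torch.zeros(1, n_tiles, dtype=torch.bool)       # True = padded tile token
    out = model(querys, contexts, instructs, coords,
                self_attention_mask=self_attention_mask, key_padding_mask=key_padding_mask)
    assert out.shape == (1, n_query, 512) and out.dtype == torch.bfloat16
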
def create_model(pretrained: str, model_arch: str, in_chans: int, local_dir: str = os.path.join(os.path.expanduser("~"), ".cache/"), **kwargs):
model = timm.create_model(model_arch, pretrained=False, in_chans=in_chans, **kwargs)
if pretrained.startswith("hf_hub:"):
hub_name = pretrained.split(":")[1]
huggingface_hub.hf_hub_download(hub_name, filename="slide_encoder.pth", local_dir=local_dir, force_download=True)
local_path = os.path.join(local_dir, "slide_encoder.pth")
else:
local_path = pretrained
if os.path.exists(local_path):
state_dict = torch.load(local_path, map_location="cpu")["model"]
missing_keys, unexpected_keys = model.load_state_dict(state_dict, strict=False)
if len(missing_keys) > 0:
for k in missing_keys:
print("Missing ", k)
if len(unexpected_keys) > 0:
for k in unexpected_keys:
print("Unexpected ", k)
print("\033[92m Successfully Loaded Pretrained GigaPath model from {} \033[00m".format(pretrained))
else:
print("\033[93m Pretrained weights not found at {}. Randomly initialized the model! \033[00m".format(local_path))
return model
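
# A usage sketch for create_model (illustrative, with assumed values): "gigapath_slide_enc2l512d" is
# registered below in this file, the hf_hub id points at the public Prov-GigaPath repository and is
# only an example, and a checkpoint trained for the original GigaPath encoder may report missing or
# unexpected keys when loaded into this modified architecture.
def _example_create_model():
    return create_model(
        pretrained="hf_hub:prov-gigapath/prov-gigapath",
        model_arch="gigapath_slide_enc2l512d",
        in_chans=4096,
    )
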
@register_model
def gigapath_slide_enc2l512d(**kwargs):
model = LongNetViT(embed_dim=512, depth=2, mlp_ratio=4, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
return model
@register_model
def gigapath_slide_enc1l512d_level0(**kwargs):
    model = LongNetViT(embed_dim=512, depth=1, mlp_ratio=4, tile_size=1024, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
return model
@register_model
def gigapath_slide_enc1l512d_level1(**kwargs):
    model = LongNetViT(embed_dim=512, depth=1, mlp_ratio=4, tile_size=2048, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
return model
@register_model
def gigapath_slide_enc1l512d_level2(**kwargs):
    model = LongNetViT(embed_dim=512, depth=1, mlp_ratio=4, tile_size=4096, norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs).to(torch.bfloat16)
return model
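
# A brief usage sketch (illustrative): the registered level variants differ only in tile_size, which
# controls the LongNet segment lengths via get_optimal_segment_length (level0: 1024px, level1: 2048px,
# level2: 4096px tiles); extra keyword arguments are forwarded to LongNetViT.
# level1_encoder = timm.create_model("gigapath_slide_enc1l512d_level1", pretrained=False, in_chans=4096)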