import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import open_clip
import gigapath.slide_encoder as slide_encoder

from torchvision import transforms
import timm
from transformers import AutoModelForCausalLM, BitsAndBytesConfig, GenerationConfig
from transformers import BertModel, BertConfig
from accelerate import Accelerator
from utils.utils import clip_path_map


class Attn_Net_Gated(nn.Module):
    """Gated attention (ABMIL-style): a tanh branch and a sigmoid gate are multiplied
    element-wise, then projected to one attention logit per head."""

    def __init__(self, L=1024, D=256, dropout=False, heads=1):
        super(Attn_Net_Gated, self).__init__()
        self.attention_a = [
            nn.Linear(L, D),
            nn.Tanh()]

        self.attention_b = [
            nn.Linear(L, D),
            nn.Sigmoid()]

        if dropout:
            self.attention_a.append(nn.Dropout(0.25))
            self.attention_b.append(nn.Dropout(0.25))

        self.attention_a = nn.Sequential(*self.attention_a)
        self.attention_b = nn.Sequential(*self.attention_b)

        self.attention_c = nn.Linear(D, heads)

    def forward(self, x, mask):
        a = self.attention_a(x)
        b = self.attention_b(x)
        A = a.mul(b)
        A = self.attention_c(A)
        # Give padded positions a large negative logit so that the softmax applied
        # downstream assigns them (near-)zero attention weight.
        A[mask == 0] = -1e9
        return A
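
# Illustrative usage sketch (not part of the training pipeline; the sizes below are
# assumptions for the example only): gated attention scores a bag of patch embeddings
# per head, and a softmax over the patch axis turns the scores into mixing weights.
#
#   attn = Attn_Net_Gated(L=512, D=256, dropout=True, heads=8)
#   x = torch.randn(100, 512)              # 100 patch embeddings
#   mask = torch.ones(100)                 # 1 = valid patch, 0 = padding
#   scores = attn(x, mask)                 # (100, 8) per-head attention logits
#   weights = F.softmax(scores, dim=0)     # normalise over the patch axis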


class Reducer(nn.Module):
    """Linear projection that reduces instruction (LLM) embeddings to the working
    embedding dimension, with an optional normalisation layer."""

    def __init__(
        self,
        in_chans=1536,
        embed_dim=768,
        norm_layer=None,
        bias=True,
    ):
        super().__init__()

        self.proj = nn.Linear(in_chans, embed_dim, bias=bias)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        B, L, D = x.shape
        x = self.proj(x)
        x = self.norm(x)
        return x


class AttentionLayer(nn.Module):
    def __init__(self, llm_dim, embed_dim, num_layers=2, num_heads=16):
        super(AttentionLayer, self).__init__()

        print("##### Vision-Text Interaction Qformer #####")

        self.num_layers = num_layers

        config = BertConfig(
            hidden_size=embed_dim,
            num_attention_heads=num_heads,
            num_hidden_layers=1
        )

        # Project instruction embeddings from the LLM width down to embed_dim.
        self.reduce = nn.Sequential(
            nn.Linear(llm_dim, embed_dim),
            nn.ReLU()
        )

        self.self_attention = nn.ModuleList([BertModel(config) for _ in range(num_layers)])
        self.cross_attention = nn.ModuleList([nn.MultiheadAttention(embed_dim=embed_dim, num_heads=num_heads, batch_first=True) for _ in range(num_layers)])

    def forward(self, input_tensor, context_tensor, instruct_tensor, self_attention_mask=None, key_padding_mask=None):
        """
        input_tensor: query features refined by self-attention and cross-attention
        context_tensor: context features used as keys/values in cross-attention
        instruct_tensor: instruction embeddings concatenated with input_tensor for self-attention
        self_attention_mask: mask for the instruction tokens (1 = valid, 0 = padding);
            it is extended internally with ones for the query positions
        key_padding_mask: cross-attention key padding mask (True = position is ignored)

        The query length for cross-attention is taken from input_tensor, so only the
        refined query tokens (not the instruction tokens) attend to the context.
        """
        query_length = input_tensor.shape[1]
        instruct_tensor = self.reduce(instruct_tensor)

        if self_attention_mask is not None:
            # Query tokens are always valid, so prepend ones to match the concatenation
            # order used below (queries first, then instruction tokens).
            input_mask_extension = torch.ones((input_tensor.shape[0], input_tensor.shape[1]), dtype=self_attention_mask.dtype, device=self_attention_mask.device)
            self_attention_mask = torch.cat((input_mask_extension, self_attention_mask), dim=-1)

        for i in range(self.num_layers):
            # Self-attention over [queries; instruction tokens], then cross-attention
            # from the refined queries to the patch (context) features.
            combined_tensor = torch.cat((input_tensor, instruct_tensor), dim=1)
            self_attn_output = self.self_attention[i](
                inputs_embeds=combined_tensor,
                attention_mask=self_attention_mask
            ).last_hidden_state
            query = self_attn_output[:, :query_length, :]

            input_tensor, _ = self.cross_attention[i](query=query, key=context_tensor, value=context_tensor, key_padding_mask=key_padding_mask)

        return input_tensor
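
# Illustrative shape sketch (the dimensions are assumptions for the example, not fixed
# by the model): learned queries are refined against patch features under instruction
# guidance, Q-Former style.
#
#   layer = AttentionLayer(llm_dim=4096, embed_dim=512, num_layers=2, num_heads=16)
#   queries  = torch.zeros(2, 16, 512)       # (batch, num_queries, embed_dim)
#   patches  = torch.randn(2, 1000, 512)     # (batch, num_patches, embed_dim)
#   instruct = torch.randn(2, 32, 4096)      # (batch, instr_len, llm_dim)
#   out = layer(queries, patches, instruct)  # (2, 16, 512) refined query embeddings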


class PPathVLM(nn.Module):
    def __init__(self, llm_requires_grad, clip_name, load_in_8bit, load_in_4bit, llm_name,
                 trust_remote_code, token, tokenizer, image_token_id, data_cache_dir='~/.cache'):
        nn.Module.__init__(self)

        self.clip_name = clip_name
        self.data_cache_dir = data_cache_dir

        self.vision_encoder, self.image_processor, self.embed_dim = self.load_vision_encoder()

        self.llm_tokenizer = tokenizer
        self.llm = self.load_llm(load_in_8bit, load_in_4bit, llm_name, trust_remote_code, token)
        self.embedding_layer = self.llm.get_input_embeddings()

        # Maps patch embeddings from the vision space into the LLM hidden size.
        self.resampler_layer = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.llm.config.hidden_size, bias=False),
            nn.ReLU(),
            nn.Dropout(0.25),
        )

        self.config = self.llm.config
        self.image_token_id = image_token_id

        # The vision encoder is always frozen.
        for param in self.vision_encoder.parameters():
            param.requires_grad = False

        # The LLM is trainable only when llm_requires_grad is set.
        for param in self.llm.parameters():
            param.requires_grad = llm_requires_grad

    def print_parameter_counts(self):
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)

        print(f"Total number of parameters: {total_params}")
        print(f"Number of trainable parameters: {trainable_params}")

    def print_llm_parameters(self, num_params=5):
        """Print a few parameters of the LLM to check requires_grad status."""
        count = 0
        for name, param in self.llm.named_parameters():
            if count >= num_params:
                break
            print(f"Parameter name: {name}")
            print(f"Parameter requires_grad: {param.requires_grad}")
            print(f"Parameter value: {param.data.flatten()[:5]}")
            print("-" * 50)
            count += 1

    def print_vision_parameters(self, num_params=5):
        """Print a few parameters of the vision encoder to check requires_grad status."""
        count = 0
        for name, param in self.vision_encoder.visual.trunk.blocks.named_parameters():
            if count >= num_params:
                break
            print(f"Parameter name: {name}")
            print(f"Parameter requires_grad: {param.requires_grad}")
            print(f"Parameter value: {param.data.flatten()[:5]}")
            print("-" * 50)
            count += 1

    def _split_and_pad(self, image_embeds, p_num):
        """
        Split the image_embeds tensor into k segments according to p_num and pad each
        segment to the maximum length. Also generate the corresponding attention mask.

        Args:
        - image_embeds (torch.Tensor): The (N, D) tensor of patch embeddings.
        - p_num (list): Number of embeddings belonging to each segment.

        Returns:
        - padded_tensors (torch.Tensor): (k, max_length, D) tensor with zero padding.
        - attention_mask_tensor (torch.Tensor): (k, max_length) mask with 1s for data
          positions and 0s for padding.
        """
        assert sum(p_num) == image_embeds.size(0), "p_num does not sum to the number of embeddings"

        # Split the flat patch sequence into one chunk per image.
        start = 0
        split_lists = []
        for num in p_num:
            split_lists.append(image_embeds[start : start + num])
            start += num

        max_length = max(p_num)

        # Zero-pad every chunk to the longest one and build its attention mask.
        padded_lists = []
        attention_masks = []
        for tensor in split_lists:
            length = tensor.size(0)
            padding = max_length - length
            padded_tensor = torch.cat([tensor, torch.zeros((padding, tensor.size(1))).to(tensor.device)], dim=0)
            padded_lists.append(padded_tensor)

            attention_mask = torch.cat([torch.ones(length), torch.zeros(padding)], dim=0).to(tensor.device)
            attention_masks.append(attention_mask)

        padded_tensors = torch.stack(padded_lists, dim=0)
        attention_mask_tensor = torch.stack(attention_masks, dim=0).long()

        return padded_tensors, attention_mask_tensor
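
    # Small worked example (values chosen purely for illustration, with `model` a
    # constructed PPathVLM): p_num = [3, 1] splits a (4, D) embedding tensor into
    # chunks of 3 and 1 patches, padded to length 3.
    #
    #   embeds = torch.randn(4, 512)
    #   padded, mask = model._split_and_pad(embeds, [3, 1])
    #   padded.shape   # (2, 3, 512)
    #   mask           # tensor([[1, 1, 1],
    #                  #         [1, 0, 0]])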

    def load_vision_encoder(self):
        print("vision_encoder loading ...")

        clip_path = clip_path_map(self.clip_name)
        if self.clip_name == "pathclip-base":
            vision_encoder, _, image_processor = open_clip.create_model_and_transforms('ViT-B-16', pretrained=clip_path, force_quick_gelu=True)
            embed_dim = 512
            vision_encoder.visual.output_tokens = True
        elif self.clip_name == "conch":
            from conch.open_clip_custom import create_model_from_pretrained
            vision_encoder, image_processor = create_model_from_pretrained('conch_ViT-B-16', clip_path)
            embed_dim = 512
            vision_encoder.visual.output_tokens = True
        elif self.clip_name == "uni":
            vision_encoder = timm.create_model("vit_large_patch16_224", img_size=224, patch_size=16, init_values=1e-5, num_classes=0, dynamic_img_size=True)
            vision_encoder.load_state_dict(torch.load(clip_path, map_location="cpu"), strict=True)

            def rgba_to_rgb(image):
                return image.convert('RGB')

            image_processor = transforms.Compose(
                [
                    transforms.Lambda(rgba_to_rgb),
                    transforms.Resize((224, 224)),
                    transforms.ToTensor(),
                    transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ]
            )
            embed_dim = 1024
        else:
            raise ValueError(f"Unsupported clip_name: {self.clip_name}")
        return vision_encoder, image_processor, embed_dim
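
    # Summary of the patch encoders wired above (checkpoint paths are resolved by
    # utils.utils.clip_path_map):
    #   "pathclip-base" -> open_clip ViT-B-16, embed_dim 512
    #   "conch"         -> CONCH ViT-B-16,     embed_dim 512
    #   "uni"           -> UNI ViT-L/16,       embed_dim 1024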

    def load_llm(self, load_in_8bit, load_in_4bit, llm_name, trust_remote_code, token):
        print("llm loading ...")
        if load_in_8bit and load_in_4bit:
            raise ValueError("You can't load the model in 8 bits and 4 bits at the same time")
        elif load_in_8bit or load_in_4bit:
            quantization_config = BitsAndBytesConfig(
                load_in_8bit=load_in_8bit, load_in_4bit=load_in_4bit
            )
            # Place the quantized model on the GPU owned by this process.
            device_map = {"": Accelerator().local_process_index}
            torch_dtype = torch.bfloat16
        else:
            device_map = None
            quantization_config = None
            torch_dtype = None

        llm = AutoModelForCausalLM.from_pretrained(
            llm_name,
            quantization_config=quantization_config,
            device_map=device_map,
            trust_remote_code=trust_remote_code,
            torch_dtype=torch_dtype,
            token=token,
            use_cache=True,
            cache_dir=self.data_cache_dir,
        )
        # Match the embedding table to the tokenizer, which may contain added special
        # tokens such as the image token.
        llm.resize_token_embeddings(len(self.llm_tokenizer))
        return llm

    def generate(self, *args, **kwargs):
        generation_config = GenerationConfig(
            max_length=100,
            temperature=1.0,
            top_k=50,
            top_p=1.0,
            num_return_sequences=1,
            repetition_penalty=1.0,
            do_sample=True,
            pad_token_id=self.llm_tokenizer.eos_token_id,
            bos_token_id=self.llm_tokenizer.bos_token_id,
        )

        with torch.no_grad():
            image = kwargs["image"]
            p_num = kwargs["patch_num"]
            input_ids = kwargs["input_ids"].to(image.device)
            attention_mask = kwargs["attention_mask"].to(image.device)

            with torch.inference_mode():
                if self.clip_name == 'uni':
                    image_embeds = self.vision_encoder(image)
                elif self.clip_name == 'conch':
                    image_embeds = self.vision_encoder.encode_image(image, normalize=False, proj_contrast=False)
                else:
                    image_embeds = self.vision_encoder.encode_image(image, normalize=False)[0]

            image_embeds, image_atts = self._split_and_pad(image_embeds, p_num)

            image_embeds = image_embeds.to(image.device)
            image_atts = image_atts.to(image.device)
            attention_mask = torch.cat([image_atts, attention_mask], dim=1)

            fusion_embs = self.get_fusion_embedding(input_ids, image_embeds)
            attention_mask = self.pad_attention_fusion(fusion_embs.size(1), attention_mask)
            res = self.llm.generate(inputs_embeds=fusion_embs, attention_mask=attention_mask, generation_config=generation_config)

        generate_list = []
        for item in res:
            generation = self.llm_tokenizer.decode(item, skip_special_tokens=True)
            generate_list.append(generation)
        return generate_list

    def pad_attention_fusion(self, new_seq_len, need_pad_seq):
        # Left-pad the attention mask with ones so that it lines up with the visual
        # prefix prepended in get_fusion_embedding.
        padd_len = new_seq_len - need_pad_seq.size(1)
        bz = need_pad_seq.size(0)

        generated_pad = torch.ones((bz, padd_len), dtype=need_pad_seq.dtype).to(need_pad_seq.device)
        paded_seq = torch.cat((generated_pad, need_pad_seq), dim=1)

        return paded_seq

    def pad_label_fusion(self, new_seq_len, labels):
        # Left-pad the labels with -100 so that the loss ignores the visual prefix.
        padd_len = new_seq_len - labels.size(1)
        bz = labels.size(0)

        generated_pad = torch.ones((bz, padd_len), dtype=labels.dtype).fill_(-100).to(labels.device)
        paded_seq = torch.cat((generated_pad, labels), dim=1)

        return paded_seq

    def get_fusion_embedding(self, input_ids, image_embs):
        token_embs = self.embedding_layer(input_ids)

        # Map the patch embeddings into the LLM embedding space.
        mapped_image_embs = self.resampler_layer(image_embs)

        # Prepend the <image> token embedding, then the mapped patch embeddings,
        # then the text token embeddings.
        image_token_emb = self.embedding_layer(torch.tensor(self.image_token_id).to(mapped_image_embs.device))
        batch_image_token_emb = image_token_emb.repeat(mapped_image_embs.size(0), 1, 1)
        fusion_embs = torch.cat((batch_image_token_emb, mapped_image_embs, token_embs), dim=1)
        return fusion_embs
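
    # Resulting sequence layout after get_fusion_embedding (lengths shown only for
    # orientation):
    #
    #   [<image> token] [patch embeddings ........] [text token embeddings ...]
    #          1              num_patches                    text_len
    #
    # pad_attention_fusion / pad_label_fusion left-pad the text-side attention mask
    # and labels so that both align with this visual prefix.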

    def forward(self, *args, **kwargs):
        image = kwargs["image"]
        p_num = kwargs["patch_num"]
        input_ids = kwargs["input_ids"].to(image.device)
        attention_mask = kwargs["attention_mask"].to(image.device)
        labels = kwargs["labels"].to(image.device)

        # The vision encoder is frozen, so patch features are extracted without gradients.
        with torch.inference_mode():
            if self.clip_name == 'uni':
                image_embeds = self.vision_encoder(image)
            elif self.clip_name == 'conch':
                image_embeds = self.vision_encoder.encode_image(image, normalize=False, proj_contrast=False)
            else:
                image_embeds = self.vision_encoder.encode_image(image, normalize=False)[0]

        image_embeds, image_atts = self._split_and_pad(image_embeds, p_num)

        image_embeds = image_embeds.to(image.device).to(torch.bfloat16)
        image_atts = image_atts.to(image.device)
        attention_mask = torch.cat([image_atts, attention_mask], dim=1)

        fusion_embs = self.get_fusion_embedding(input_ids, image_embeds)
        attention_mask = self.pad_attention_fusion(fusion_embs.size(1), attention_mask)
        labels = self.pad_label_fusion(fusion_embs.size(1), labels)

        output = self.llm(inputs_embeds=fusion_embs, attention_mask=attention_mask, labels=labels)
        return output
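
# Training-step sketch for PPathVLM (keyword names follow forward(); the shapes are
# assumptions chosen for illustration):
#
#   out = model(
#       image=patch_batch,          # (N, 3, 224, 224) patches from all images in the batch
#       patch_num=[3, 5],           # patches per image, summing to N
#       input_ids=input_ids,        # (B, T) question/answer tokens
#       attention_mask=text_mask,   # (B, T)
#       labels=labels,              # (B, T), with -100 on positions to ignore
#   )
#   loss = out.loss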


class WPathVLM(PPathVLM):
    """WSI-level variant: patch features from several magnification levels are
    aggregated into slide tokens (ABMIL, LongNet, or Q-Former style adaptors) and
    prepended to the LLM input."""

    def __init__(self, llm_requires_grad, load_in_8bit, load_in_4bit, llm_name,
                 trust_remote_code, token, tokenizer, image_token_id,
                 n_heads='32,16,8', n_level=3, embed_dim=512,
                 agg_strategy='abmil', hierachical_token=True, hierachical_adaptor=True,
                 data_cache_dir='~/.cache'):

        nn.Module.__init__(self)

        self.llm_tokenizer = tokenizer
        self.data_cache_dir = data_cache_dir
        self.llm = self.load_llm(load_in_8bit, load_in_4bit, llm_name, trust_remote_code, token)
        self.embedding_layer = self.llm.get_input_embeddings()
        self.n_heads = [int(n_head) for n_head in n_heads.split(',')]
        self.n_level = n_level
        self.agg_strategy = agg_strategy
        self.embed_dim = embed_dim
        self.config = self.llm.config
        self.image_token_id = image_token_id
        self.hierachical_token = hierachical_token
        self.hierachical_adaptor = hierachical_adaptor

        self.reduce = Reducer(self.llm.config.hidden_size, embed_dim)

        # Either one adaptor per magnification level or a single shared adaptor.
        if self.hierachical_adaptor:
            self.adaptor_level = n_level
        else:
            self.adaptor_level = 1

        if self.agg_strategy == 'abmil':
            size = [embed_dim, int(embed_dim / 2)]
            self.att_net = nn.ModuleList([
                Attn_Net_Gated(L=size[0], D=size[1], dropout=True, heads=self.n_heads[i])
                for i in range(self.n_level)
            ])

        if self.agg_strategy == "longnet":
            # Learnable query tokens, one set per level; nn.ParameterList registers them
            # with the module so they are moved, saved, and optimised together with it.
            self.query = nn.ParameterList([nn.Parameter(torch.zeros(1, self.n_heads[i], embed_dim)) for i in range(self.n_level)])
            self.longnet_encoder_list = nn.ModuleList([
                slide_encoder.create_model(pretrained="", model_arch=f"gigapath_slide_enc1l512d_level{level}", in_chans=self.llm.config.hidden_size)
                for level in range(self.adaptor_level)]).to(torch.bfloat16)

        if self.agg_strategy == "qformer":
            self.query = nn.ParameterList([nn.Parameter(torch.zeros(1, self.n_heads[i], embed_dim)) for i in range(self.n_level)])
            self.qformer_encoder_list = nn.ModuleList([AttentionLayer(self.llm.config.hidden_size, embed_dim) for _ in range(self.adaptor_level)]).to(torch.bfloat16)

        if self.agg_strategy in ["qformer", "longnet"]:
            for tensor in self.query:
                nn.init.xavier_uniform_(tensor)

        # Maps aggregated slide tokens from the vision space into the LLM hidden size.
        self.resampler_layer = nn.Sequential(
            nn.LayerNorm(self.embed_dim),
            nn.Linear(self.embed_dim, self.llm.config.hidden_size, bias=False),
            nn.ReLU(),
            nn.Dropout(0.25),
        )

        # The LLM is trainable only when llm_requires_grad is set.
        for param in self.llm.parameters():
            param.requires_grad = llm_requires_grad

    def get_wsi_embedding(self, patch_embs, patch_masks, coords, level, input_ids_instruct=None, attention_mask_instruct=None):

        # Use one adaptor per level, or the single shared adaptor.
        if self.hierachical_adaptor:
            adaptor_level = level
        else:
            adaptor_level = 0

        if self.agg_strategy == 'abmil':
            batch_size, num_patches, embedding_size = patch_embs.shape

            patch_embs_flattened = patch_embs.view(batch_size * num_patches, embedding_size)
            patch_masks_flattened = patch_masks.view(batch_size * num_patches)

            # Gated attention logits per head, softmax-normalised over the patch axis.
            patch_attention_matrices = self.att_net[level](patch_embs_flattened, patch_masks_flattened)
            patch_attention_matrices = patch_attention_matrices.view(batch_size, num_patches, self.n_heads[level])
            patch_attention_matrices = F.softmax(patch_attention_matrices, dim=1)

            # Attention-weighted sum of patch embeddings -> (batch, heads, embed_dim) slide tokens.
            mapped_patch_embs = patch_embs_flattened.view(batch_size, num_patches, embedding_size)
            agged_WSI_embs = patch_attention_matrices.unsqueeze(-1) * mapped_patch_embs.unsqueeze(-2)
            agged_WSI_embs = torch.sum(agged_WSI_embs, dim=1)

        elif self.agg_strategy == "longnet":
            # Patch size associated with each magnification level, passed to the slide encoder.
            patch_size_dict = {0: 1024.0, 1: 2048.0, 2: 4096.0}
            patch_size_dict = {key: torch.tensor([value]) for key, value in patch_size_dict.items()}
            query = self.query[level].repeat(patch_embs.shape[0], 1, 1)
            key_padding_mask = patch_masks.bool()
            instruct_embs = self.embedding_layer(input_ids_instruct)
            instruct_embs = self.reduce(instruct_embs)

            agged_WSI_embs = self.run_inference_with_slide_encoder(query.cuda().to(torch.bfloat16), patch_embs.to(torch.bfloat16),
                                                                   instruct_embs, coords.to(torch.bfloat16),
                                                                   self_attention_mask=attention_mask_instruct,
                                                                   key_padding_mask=~key_padding_mask,
                                                                   slide_encoder_model=self.longnet_encoder_list[adaptor_level],
                                                                   patch_size=patch_size_dict[level].cuda().to(torch.bfloat16))

        elif self.agg_strategy == "qformer":
            query = self.query[level].repeat(patch_embs.shape[0], 1, 1)
            key_padding_mask = patch_masks.bool()
            instruct_embs = self.embedding_layer(input_ids_instruct)
            agged_WSI_embs = self.qformer_encoder_list[adaptor_level](query.cuda().to(torch.bfloat16), patch_embs.to(torch.bfloat16),
                                                                      instruct_embs.to(torch.bfloat16),
                                                                      self_attention_mask=attention_mask_instruct,
                                                                      key_padding_mask=~key_padding_mask)

        else:
            # No aggregation: pass the padded patch embeddings through unchanged.
            agged_WSI_embs = patch_embs

        return agged_WSI_embs
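
    # Shape walk-through for the 'abmil' branch (sizes are illustrative): with
    # embed_dim=512 and n_heads='32,16,8', level 0 turns (B, N, 512) patch features
    # into (B, 32, 512) slide tokens:
    #
    #   logits  : (B, N, 32)   gated attention scores, padded patches masked out
    #   weights : (B, N, 32)   softmax over the patch axis N
    #   tokens  : (B, 32, 512) sum over N of weights[:, n, :, None] * patches[:, n, None, :]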

    def run_inference_with_slide_encoder(self, query, tile_embeds, instruct_embs, coords, slide_encoder_model, patch_size=256,
                                         self_attention_mask=None, key_padding_mask=None):
        """
        Run inference with the slide encoder.

        Arguments:
        ----------
        query : torch.Tensor
            Learnable query tokens for this level
        tile_embeds : torch.Tensor
            Tile embeddings
        instruct_embs : torch.Tensor
            Instruction embeddings
        coords : torch.Tensor
            Coordinates of the tiles
        slide_encoder_model : torch.nn.Module
            Slide encoder model
        patch_size : torch.Tensor or float
            Patch size associated with this level
        """
        if len(tile_embeds.shape) == 2:
            tile_embeds = tile_embeds.unsqueeze(0)
            coords = coords.unsqueeze(0)

        with torch.cuda.amp.autocast(dtype=torch.float16):
            slide_embeds = slide_encoder_model(query, tile_embeds, instruct_embs, coords, patch_size,
                                               self_attention_mask=self_attention_mask, key_padding_mask=key_padding_mask)

        return slide_embeds

    def get_fusion_embedding(self, input_ids, agged_WSI_embs):
        # Flat layout: the <image> token, then the slide tokens from every level, then text.
        token_embs = self.embedding_layer(input_ids)
        image_token_emb = self.embedding_layer(torch.tensor(self.image_token_id).to(agged_WSI_embs[0].device))
        batch_image_token_emb = image_token_emb.repeat(agged_WSI_embs[0].size(0), 1, 1)

        agged_WSI_embs = torch.cat(agged_WSI_embs, dim=1)
        agged_WSI_embs = self.resampler_layer(agged_WSI_embs)

        fusion_embs = torch.cat((batch_image_token_emb, agged_WSI_embs, token_embs), dim=1)
        return fusion_embs

    def get_fusion_embedding_hierarchy(self, input_ids, agged_WSI_embs):
        # Hierarchical layout: the <image> token, then for each magnification level a
        # level-specific token followed by that level's slide tokens, then text.
        text_token_embs = self.embedding_layer(input_ids)
        image_token_emb = self.embedding_layer(torch.tensor(self.image_token_id[0]).to(agged_WSI_embs[0].device))
        batch_image_token_emb = image_token_emb.repeat(agged_WSI_embs[0].size(0), 1, 1)

        mag_WSI_embs = []

        for i in range(self.n_level):
            mag_token_emb = self.embedding_layer(torch.tensor(self.image_token_id[i + 1]).to(agged_WSI_embs[0].device))
            batch_mag_token_emb = mag_token_emb.repeat(agged_WSI_embs[0].size(0), 1, 1)
            agged_WSI_emb = self.resampler_layer(agged_WSI_embs[i])
            mag_WSI_emb = torch.cat((batch_mag_token_emb, agged_WSI_emb), dim=1)
            mag_WSI_embs.append(mag_WSI_emb)

        mag_WSI_embs = torch.cat(mag_WSI_embs, dim=1)

        fusion_embs = torch.cat((batch_image_token_emb, mag_WSI_embs, text_token_embs), dim=1)
        return fusion_embs
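
    # Resulting hierarchical layout (with the default n_heads='32,16,8' the abmil and
    # qformer adaptors contribute n_heads[level] slide tokens per level):
    #
    #   [<image>] [<mag0>] [level-0 tokens] [<mag1>] [level-1 tokens]
    #             [<mag2>] [level-2 tokens] [text token embeddings ...]
    #
    # image_token_id is expected to hold n_level + 1 ids here: the <image> token id
    # followed by one id per magnification level.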

    def generate(self, *args, **kwargs):
        generation_config = GenerationConfig(
            max_length=200,
            temperature=1.0,
            top_k=50,
            top_p=0.95,
            num_return_sequences=1,
            repetition_penalty=1.1,
            do_sample=True,
            pad_token_id=self.llm_tokenizer.eos_token_id,
            bos_token_id=self.llm_tokenizer.bos_token_id,
        )

        with torch.no_grad():
            input_ids = kwargs["input_ids"]
            text_attention_mask = kwargs["attention_mask"]
            input_ids_instruct = kwargs["input_ids_instruct"]
            attention_mask_instruct = kwargs["attention_mask_instruct"]

            agged_WSI_embs = []

            # Aggregate each level's patch features into slide tokens.
            for level in range(self.n_level):
                patch_embs = kwargs["fea{}".format(level)].float()
                patch_attention_mask = kwargs["mask{}".format(level)]
                coords = kwargs["cor{}".format(level)]
                agged_WSI_embs_level = self.get_wsi_embedding(patch_embs, patch_attention_mask, coords, level,
                                                              input_ids_instruct=input_ids_instruct,
                                                              attention_mask_instruct=attention_mask_instruct)
                agged_WSI_embs.append(agged_WSI_embs_level.float())

            if self.hierachical_token:
                fusion_embs = self.get_fusion_embedding_hierarchy(input_ids, agged_WSI_embs)
            else:
                fusion_embs = self.get_fusion_embedding(input_ids, agged_WSI_embs)
            text_attention_mask = self.pad_attention_fusion(fusion_embs.size(1), text_attention_mask)

            res = self.llm.generate(inputs_embeds=fusion_embs, attention_mask=text_attention_mask, generation_config=generation_config)

        generate_list = []
        for item in res:
            generation = self.llm_tokenizer.decode(item, skip_special_tokens=True)
            generate_list.append(generation)
        return generate_list

    def forward(self, *args, **kwargs):
        input_ids = kwargs["input_ids"]
        text_attention_mask = kwargs["attention_mask"]
        input_ids_instruct = kwargs["input_ids_instruct"]
        attention_mask_instruct = kwargs["attention_mask_instruct"]
        labels = kwargs["labels"]
        agged_WSI_embs = []

        # Aggregate each level's patch features into slide tokens.
        for level in range(self.n_level):
            patch_embs = kwargs["fea{}".format(level)]
            patch_attention_mask = kwargs["mask{}".format(level)]
            coords = kwargs["cor{}".format(level)]
            agged_WSI_embs_level = self.get_wsi_embedding(patch_embs, patch_attention_mask, coords, level,
                                                          input_ids_instruct=input_ids_instruct,
                                                          attention_mask_instruct=attention_mask_instruct)
            agged_WSI_embs.append(agged_WSI_embs_level)

        if self.hierachical_token:
            fusion_embs = self.get_fusion_embedding_hierarchy(input_ids, agged_WSI_embs)
        else:
            fusion_embs = self.get_fusion_embedding(input_ids, agged_WSI_embs)

        text_attention_mask = self.pad_attention_fusion(fusion_embs.size(1), text_attention_mask)
        labels = self.pad_label_fusion(fusion_embs.size(1), labels)

        output = self.llm(inputs_embeds=fusion_embs, attention_mask=text_attention_mask, labels=labels)
        return output
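
# Training-step sketch for WPathVLM (keyword names follow forward(); shapes are
# assumptions for illustration, with three magnification levels):
#
#   out = model(
#       input_ids=input_ids,                   # (B, T) question/answer tokens
#       attention_mask=text_mask,              # (B, T)
#       labels=labels,                         # (B, T), with -100 on positions to ignore
#       input_ids_instruct=instr_ids,          # (B, Ti) instruction tokens for the adaptors
#       attention_mask_instruct=instr_mask,    # (B, Ti)
#       fea0=fea0, mask0=mask0, cor0=cor0,     # level-0 patch features, mask, coordinates
#       fea1=fea1, mask1=mask1, cor1=cor1,
#       fea2=fea2, mask2=mask2, cor2=cor2,
#   )
#   loss = out.loss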