diff --git "a/minigpt4/models/minigpt_base.py" "b/minigpt4/models/minigpt_base.py" --- "a/minigpt4/models/minigpt_base.py" +++ "b/minigpt4/models/minigpt_base.py" @@ -4,30 +4,26 @@ import random import torch from torch.cuda.amp import autocast as autocast import torch.nn as nn -import open_clip -from minigpt4.common.registry import registry -from minigpt4.models.base_model import BaseModel -from transformers import StoppingCriteria, StoppingCriteriaList,AutoTokenizer, AutoModel -from minigpt4.models.base_model import LayerNorm -from minigpt4.conversation.conversation import StoppingCriteriaSub -from minigpt4.models.eva_vit import Block -from minigpt4.models.mae_vit import mae_vit_large_patch16 -# from transformers import -import math import torch.nn.functional as F from einops import rearrange from einops.layers.torch import Rearrange -from open_clip.factory import HF_HUB_PREFIX, _MODEL_CONFIGS -import pdb + +from minigpt4.common.registry import registry +from minigpt4.models.base_model import disabled_train +from minigpt4.models.minigpt_base import MiniGPTBase +from minigpt4.models.Qformer import BertConfig, BertLMHeadModel + + + + class SpatialPoolingProjector(nn.Module): - def __init__(self, image_size=(32, 256, 256), patch_size=(8, 16, 16), in_dim=768, out_dim=768, layer_type='linear', layer_num=6, pooling_type='spatial', pooling_size=2): - # print(in_dim,out_dim) + def __init__(self, image_size=(32, 256, 256), patch_size=(4, 16, 16), in_dim=768, out_dim=768, layer_type='linear', layer_num=6, pooling_type='spatial', pooling_size=4): super().__init__() self.in_dim = in_dim self.pooling_size = pooling_size - self.num_patches_pre = [img // pch for img, pch in zip(image_size, patch_size)] + self.num_patches_pre = [img // pch for img, pch in zip(image_size, patch_size)]#8,16,16 #8,16,16 self.num_patches_post = [num // pooling_size for num in self.num_patches_pre] #4,8,8 @@ -50,7 +46,7 @@ class SpatialPoolingProjector(nn.Module): self.pooling_type = pooling_type def forward(self, x): - B = x.shape[0] # B*N*D + B = x.shape[0] # B*N*D 1*2048*768 # print(x.shape) if self.pooling_type == 'spatial': to_3d = Rearrange("b (p1 p2 p3) d -> b d p1 p2 p3", b=B, d=self.in_dim, p1=self.num_patches_pre[0], p2=self.num_patches_pre[1], p3=self.num_patches_pre[2]) @@ -64,7 +60,6 @@ class SpatialPoolingProjector(nn.Module): x = x.permute(0, 2, 1) #b n d x = rearrange(x, "b n d -> (b n) d") - # print(x.shape) x = self.projector(x) x = rearrange(x, "(b n) d -> b n d", b=B) @@ -79,1928 +74,551 @@ class SpatialPoolingProjector(nn.Module): -def calculate_iou(box1,box2): - y1max, x1max, y1min, x1min = box1[0], box1[1], box1[2], box1[3] - y2max, x2max, y2min, x2min = box2[0], box2[1], box2[2], box2[3] - #计算两个框的面积 - s1 = (y1max - y1min + 1.) * (x1max - x1min + 1.) - s2 = (y2max - y2min + 1.) * (x2max - x2min + 1.) 
- - #计算相交部分的坐标 - xmin = max(x1min,x2min) - ymin = max(y1min,y2min) - xmax = min(x1max,x2max) - ymax = min(y1max,y2max) - - inter_h = max(ymax - ymin + 1, 0) - inter_w = max(xmax - xmin + 1, 0) - - intersection = inter_h * inter_w - union = s1 + s2 - intersection - #计算iou - iou = intersection / union - return iou - -def calculate_giou_loss(ground_truth, predicted): +@registry.register_model("minigpt4") +class MiniGPT4(MiniGPTBase): """ - 计算GIOU损失 + MiniGPT-4 model """ - iou = calculate_iou(ground_truth, predicted) - - # 计算GIOU的其他三个项 - enclose_y1 = torch.min(ground_truth[:, 0], predicted[:, 0]) - enclose_x1 = torch.min(ground_truth[:, 1], predicted[:, 1]) - enclose_y2 = torch.max(ground_truth[:, 2], predicted[:, 2]) - enclose_x2 = torch.max(ground_truth[:, 3], predicted[:, 3]) - - enclose_area = torch.clamp(enclose_x2 - enclose_x1, min=0) * torch.clamp(enclose_y2 - enclose_y1, min=0) - print(iou) - giou = iou - (enclose_area - calculate_iou(ground_truth, predicted)) / (enclose_area + 1e-10) - - giou_loss = 1.0 - giou - return giou_loss.mean() - - -class mulit_modality_Attention(nn.Module): - def __init__(self, in_channels=16, out_channels=8,emb_dim=4096, output_dim=4096,att_dropout=0.0, aropout=0.0): - super(mulit_modality_Attention, self).__init__() - self.emb_dim = emb_dim - self.scale = emb_dim ** -0.5 - - # self.proj_in = nn.Linear(in_channels, emb_dim, kernel_size=1, stride=1, padding=0) - - self.Wq = nn.Linear(emb_dim, emb_dim) - self.Wk = nn.Linear(emb_dim, emb_dim) - self.Wv = nn.Linear(emb_dim, emb_dim) - - self.softmax = nn.Softmax(dim=-1) - self.output_linear=nn.Linear(emb_dim,emb_dim) - self.output_proj = nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1)) - - - def forward(self, x_list, context=None, pad_mask=None): - ''' - - :param x: [1-4, 32, 4096] - :param context: [batch_szie, seq_len, emb_dim] - :param pad_mask: [batch_size, seq_len, seq_len] - :return: - ''' - b, h, w = x_list.shape - - # query_list=[] - # key_list=[] - # value_list=[] - # for i in range(b): - # query_list.append(self.Wq(x_list[i,:,:])) - # key_list.append(self.Wk(x_list[i,:,:])) - # value_list.append(self.Wv(x_list[i,:,:])) - # print(query_list[0].shape) - - query_list=self.Wq(x_list) - key_list=self.Wk(x_list) - value_list=self.Wv(x_list) - # attention_score=torch.einsum('bid,bjd->bij',query_list,key_list) - q_c,q_w,q_h=query_list.shape - - # query_list=query_list.reshape(q_c,-1,q_w,q_h) - # key_list=key_list.reshape(-1,q_c,q_w,q_h) - - # print("shape1",query_list.shape) - # print('shape2',key_list.shape) - # attention_score=torch.einsum('bij') - attention_score=torch.matmul(query_list,key_list.transpose(-1,-2)) - attention_score=self.softmax(attention_score*self.scale) - # print(attention_score.shape) - - attention_score=attention_score.reshape(q_c,-1,q_w,q_w) - value_list=value_list.reshape(-1,q_c,q_w,q_h) - - output=torch.matmul(attention_score,value_list) - output=self.output_linear(output) - o_b,o_c,o_w,o_d=output.shape - output=output.reshape(1,o_b*o_c,o_w,o_d) - # print(output.shape) - if o_b*o_c==16: - output=self.output_proj(output) - output=output.reshape(1,-1,o_w,o_d) - - - # # [batch_size, h*w, seq_len] - # att_weights = torch.einsum('bid,bjd -> bij', Q, K) - # att_weights = att_weights * self.scale - - # if pad_mask is not None: - # # [batch_size, h*w, seq_len] - # att_weights = att_weights.masked_fill(pad_mask, -1e9) - - # att_weights = F.softmax(att_weights, dim=-1) - # out = torch.einsum('bij, bjd -> bid', att_weights, V) # [batch_size, h*w, emb_dim] - - # out = rearrange(out, 'b (h w) c -> 
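The removed calculate_iou works on scalar box coordinates while calculate_giou_loss indexes (N, 4) tensors and plugs an IoU value into the enclosing-box term where the standard formulation uses the union area, so the two helpers do not compose cleanly. A hedged, batched sketch of conventional GIoU for the same [ymax, xmax, ymin, xmin] box layout (an illustrative helper, not the repo's):

```python
import torch

def giou(pred, target, eps=1e-10):
    """GIoU for boxes given as [ymax, xmax, ymin, xmin], shape (N, 4). Illustrative helper."""
    py1, px1, py0, px0 = pred[:, 0], pred[:, 1], pred[:, 2], pred[:, 3]
    ty1, tx1, ty0, tx0 = target[:, 0], target[:, 1], target[:, 2], target[:, 3]

    inter = (torch.min(px1, tx1) - torch.max(px0, tx0)).clamp(min=0) * \
            (torch.min(py1, ty1) - torch.max(py0, ty0)).clamp(min=0)
    union = (px1 - px0) * (py1 - py0) + (tx1 - tx0) * (ty1 - ty0) - inter
    iou = inter / (union + eps)

    # Smallest enclosing box; GIoU subtracts the fraction of it not covered by the union.
    enclose = (torch.max(px1, tx1) - torch.min(px0, tx0)) * (torch.max(py1, ty1) - torch.min(py0, ty0))
    return iou - (enclose - union) / (enclose + eps)

pred = torch.tensor([[0.8, 0.8, 0.2, 0.2]])   # [ymax, xmax, ymin, xmin]
gt = torch.tensor([[0.9, 0.9, 0.3, 0.3]])
giou_loss = (1.0 - giou(pred, gt)).mean()
```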
b c h w', h=h, w=w) # [batch_size, c, h, w] - # out = self.proj_out(out) # [batch_size, c, h, w] - - # print(out.shape) - - return output - -# 以前出现Nan的cross att版本 -# class resolution_attention(nn.Module): - - -# def __init__(self, in_channels=16, out_channels=8,emb_dim=768, output_dim=768,att_dropout=0.0, aropout=0.0): -# super(resolution_attention, self).__init__() -# self.emb_dim = emb_dim -# self.scale = emb_dim ** -0.5 - -# # self.proj_in = nn.Linear(in_channels, emb_dim, kernel_size=1, stride=1, padding=0) -# # self.batch_normal_layer_hr=nn.BatchNorm1d(1) -# # self.batch_normal_layer_lr=nn.BatchNorm1d(1) -# self.Wq = nn.Linear(emb_dim, emb_dim) -# self.Wk = nn.Linear(emb_dim, emb_dim) -# self.Wv = nn.Linear(emb_dim, emb_dim) -# self.bn_layer1=nn.BatchNorm1d(emb_dim) -# self.bn_layer2=nn.BatchNorm1d(emb_dim) -# self.softmax = nn.Softmax(dim=-1) -# self.output_linear=nn.Linear(emb_dim,emb_dim) -# self.relu_fn=nn.ReLU() -# # self.output_proj = nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1)) -# self.drop_out=nn.Dropout(0.) - - - -# def forward(self, HR_image, LR_image, context=None, pad_mask=None):#这里命名反了,HR实际上是lr,lr是hr -# ''' - -# :param x: [1-4, 32, 4096] -# :param context: [batch_szie, seq_len, emb_dim] -# :param pad_mask: [batch_size, seq_len, seq_len] -# :return: -# ''' -# #batch size is 1 -# HR_image=self.bn_layer1(HR_image.reshape(-1,HR_image.size(-1))).reshape(1,-1,HR_image.size(-1)) # 1, 128, 4096 -# LR_image=self.bn_layer2(LR_image.reshape(-1,HR_image.size(-1))).reshape(1,-1,HR_image.size(-1)) # 1, 512, 4096 - -# query_list=self.Wq(HR_image) # 1, 128, 4096 -# key_list=self.Wk(LR_image) # 1, 512, 4096 -# value_list=self.Wv(LR_image) # 1, 512, 4096 - - -# attention_score=torch.matmul(query_list,key_list.transpose(-1,-2)) # 1, 128, 512 -# attention_score=self.softmax(attention_score*self.scale) # 1, 128, 512 - -# output=torch.matmul(attention_score,value_list) # 1, 128, 4096 -# output=self.output_linear(output) + HR_image # 1, 128, 4096 - - -# return output - - -def attention(query, key, value, mask=None, dropout=None): - d_k = query.size(-1) # 768 - scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # 1,5,32,1,784 vs 1,5,32,1,4 vs 1, 160, 512 - if mask is not None: - scores = scores.masked_fill(mask == 0, -1e9) - p_attn = F.softmax(scores, dim=-1) - if dropout is not None: - p_attn = dropout(p_attn) - return torch.matmul(p_attn, value), p_attn - - -class resolution_attention(nn.Module): - - def __init__(self, in_channels=16, out_channels=8, emb_dim=768, output_dim=768, dropout=0.1, aropout=0.0): - super(resolution_attention, self).__init__() - self.emb_dim = emb_dim - self.Wq = nn.Linear(emb_dim, emb_dim) - self.Wk = nn.Linear(emb_dim, emb_dim) - self.Wv = nn.Linear(emb_dim, emb_dim) - self.attn = None - self.output_linear = nn.Linear(emb_dim,emb_dim) - self.dropout = nn.Dropout(p=dropout) # 注意力得分后的 Dropout - self.dropout_2 = nn.Dropout(p=dropout) # 残差连接后的 Dropout - self.norm = nn.LayerNorm(emb_dim) # 残差连接后的 LayerNorm - - def forward(self, LR_image, HR_image, context=None, mask=None): - ''' - :param x: [1-4, 32, 768] - :param context: [batch_szie, seq_len, emb_dim] - :param pad_mask: [batch_size, seq_len, seq_len] - :return: - ''' - version = 2 # or 1 2 - if version == 0: - # previous version 全部图像的LR看全部图像的HR,这个时候如果不同样本模态不一样的话,可能会出现问题 - LR_image = LR_image.reshape(-1,LR_image.size(-1)).reshape(1,-1,LR_image.size(-1)) # 1, 128, 768 - HR_image = HR_image.reshape(-1,HR_image.size(-1)).reshape(1,-1,HR_image.size(-1)) # 1, 512, 768 - elif version == 1: - # 修改version 
每张图像的LR看HR - LR_image = LR_image.view(1, -1, 32, 1, LR_image.size(-1)) - HR_image = HR_image.view(1, -1, 32, 4, HR_image.size(-1)) # HR中有4倍的图像数量,即32*4 - elif version == 2: - # breakpoint() - # print(f"LR_image.shape:{LR_image.shape},HR_image.shape:{HR_image.shape}") - # 修改 ViT HR 196*224*4, 1, 768 LR 224, 1, 768 - LR_image = LR_image.view(1, -1, 32, 1, LR_image.size(-1)) - HR_image = HR_image.view(1, -1, 32, 4*196, HR_image.size(-1)) # HR中有4倍的图像数量,外加196个visual token,即32*4*196 - query_list=self.Wq(LR_image) # version3: 1,7,32,1,768 version2: 1,5,32,1,768 version1: 1, 160, 768 - key_list=self.Wk(HR_image) # version3: 1,7,32,784,768 version2: 1,5,32,4,768 version1: 1, 640, 768 - value_list=self.Wv(HR_image) # version3: 1,7,32,784,768 version2: 1,5,32,4,768 version1: 1, 640, 768 - x, self.attn = attention(query_list, key_list, value_list, mask=mask, dropout=self.dropout) # 1,5,32,1,768 - if version in [1, 2]: - # 把维度变换一下到 1, _, 768 - x = x.view(1,-1,LR_image.size(-1)) - query_list = query_list.view(1,-1,LR_image.size(-1)) - - x = self.output_linear(x) # [1, 128, 768] - # 使用残差连接 采用Transformer的技术细节 - x = self.norm(query_list + self.dropout_2(x)) # [1, 128, 768] - return x - - -def bbox_giou_loss(pred_bboxes, target_bboxes): - """ - Compute the Generalized IoU (GIOU) loss between predicted bounding boxes and target bounding boxes. - - Args: - - pred_bboxes: Predicted bounding boxes, tensor of shape (N, 4), where N is the number of bounding boxes. - Each bounding box is represented as [ymax, xmax, ymin, xmin]. - - target_bboxes: Target bounding boxes, tensor of shape (N, 4), where N is the number of bounding boxes. - Each bounding box is represented as [ymax, xmax, ymin, xmin]. - - Returns: - - giou_loss: Computed GIOU loss. - """ - x_pred_max=torch.max(pred_bboxes[:,1],pred_bboxes[:,3]) - x_pred_min=torch.min(pred_bboxes[:,1],pred_bboxes[:,3]) - y_pred_min=torch.min(pred_bboxes[:,0],pred_bboxes[:,2]) - y_pred_max=torch.max(pred_bboxes[:,0],pred_bboxes[:,2]) - - x_gt_max=torch.max(target_bboxes[:,1],target_bboxes[:,3]) - x_gt_min=torch.min(target_bboxes[:,1],target_bboxes[:,3]) - y_gt_min=torch.min(target_bboxes[:,0],target_bboxes[:,2]) - y_gt_max=torch.max(target_bboxes[:,0],target_bboxes[:,2]) - - - ymin = torch.max(y_pred_min, y_gt_min) - xmin = torch.max(x_pred_min, x_gt_min) - ymax = torch.min(y_pred_max, y_gt_max) - xmax = torch.min(x_pred_max, x_gt_max) - - intersection = torch.clamp((xmax - xmin), min=0) * torch.clamp((ymax - ymin), min=0) - # Calculate union - pred_area = (x_pred_max - x_pred_min) * (y_pred_max - y_pred_min) - target_area = (x_gt_max - x_gt_min) * (y_gt_max - y_gt_min) - union = pred_area + target_area - intersection - - - iou=intersection / union - # Calculate loss - # giou_loss = 1 - giou - # print(iou) - giou_loss=1-iou - return giou_loss.mean() - - - - - -class MiniGPTBase(BaseModel): - """ - Base class for MiniGPT-4 and MiniGPT-v2 - """ + PRETRAINED_MODEL_CONFIG_DICT = { + "pretrain_vicuna0": "configs/models/minigpt4_vicuna0.yaml", + "pretrain_llama2": "configs/models/minigpt4_llama2.yaml", + } def __init__( - self, - vit_model="eva_clip_g", - img_size=224, - drop_path_rate=0, - use_grad_checkpoint=False, - vit_precision="fp16", - freeze_vit=False, - llama_model="", - max_txt_len=32, - max_context_len=3800, - prompt_template="", - end_sym='\n', - low_resource=False, # use 8 bit and put vit in cpu - device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. 
- lora_r=8, # or 1 lora_r means lora is not used #lora_r=8,lora_a=32 - lora_target_modules=["q_proj", "v_proj"], - lora_alpha=32, # or 64 128/256 - lora_dropout=0.05, - modality_number=5, - model_2d_or_3d="3d", - self_training=False + self, + vit_model="eva_clip_g", + q_former_model="https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth", + img_size=224, + drop_path_rate=0, + use_grad_checkpoint=False, + vit_precision="fp16", + freeze_vit=False, + has_qformer=True, + freeze_qformer=True, + num_query_token=32, + llama_model="", + prompt_path="", + prompt_template="", + max_txt_len=32, + end_sym='\n', + low_resource=False, # use 8 bit and put vit in cpu + device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore. + modality_number=1, + model_2d_or_3d="2d", + self_training=False ): - super().__init__() + super().__init__( + vit_model=vit_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + llama_model=llama_model, + max_txt_len=max_txt_len, + end_sym=end_sym, + low_resource=low_resource, + device_8bit=device_8bit, + modality_number=modality_number, + model_2d_or_3d=model_2d_or_3d, + self_training=self_training + ) + if self.model_2d_or_3d=='2d': + self.v_q_project=nn.Linear(self.visual_encoder_list[0].num_features,1408) + elif self.model_2d_or_3d=='3d': + self.v_q_project=nn.Linear(self.visual_encoder_list[0].hidden_size,1408) + self.self_training=self_training - freeze_vit=not self_training - self.model_2d_or_3d=model_2d_or_3d - print('self.model_2d_or_3d',self.model_2d_or_3d) - print('self_training is ',self_training) - if self.self_training==False: - self.llama_model, self.llama_tokenizer = self.init_llm( - llama_model_path=llama_model, - low_resource=low_resource, - low_res_device=device_8bit, - lora_r=lora_r, - lora_target_modules=lora_target_modules, - lora_alpha=lora_alpha, - lora_dropout=lora_dropout, - ) - self.visual_encoder_list=[] - self.ln_vision_list=[] self.modality_number=modality_number + self.modalities={"t1n":0,"t1c":0,"t2f":1,"t2w":1,"CT":2,'other':2} + self.has_qformer = False - # 加速一下,先把模型挂载到GPU上,不然太慢 - self.llama_model = self.llama_model.cuda() - - # 2024.09.07 Yanzhaoshi修改: add_special_tokens 给tokenizer加上特殊字符,特殊字符应作为一个整体 - self.llama_tokenizer.add_special_tokens({'additional_special_tokens':["", "", "", ""]}) - # self.llama_tokenizer.add_special_tokens({'additional_special_tokens':[f'' for i in range(12)]}) - # If there is a mismatch between tokenizer vocab size and embedding matrix, self.llama_model.base_model.model.model - # throw a warning and then expand the embedding matrix - if len(self.llama_tokenizer) > self.llama_model.get_input_embeddings().weight.shape[0] and lora_r>0: - print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.") - self.llama_model.base_model.model.model.resize_token_embeddings(new_num_tokens = len(self.llama_tokenizer)) - else: - self.llama_model.base_model.resize_token_embeddings(new_num_tokens = len(self.llama_tokenizer)) - # self.llama_tokenizer.add_special_tokens({'additional_special_tokens':[""]}) - # breakpoint() - self.start_visual_token_idx = None # 记录下第一个视觉token的位置 - self.end_visual_token_idx = None # 记录下最后一个视觉token的位置 - - # self.llama_model.resize_token_embeddings(len(self.llama_tokenizer)) - # box_token_id = self.llama_tokenizer.convert_tokens_to_ids('') - # BOX_TOKEN='' - # 
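The removed init_llm path wraps the LLaMA checkpoint with LoRA (r=8, alpha=32, dropout=0.05 on q_proj/v_proj) and resizes the embedding matrix after new special tokens are registered, reaching through base_model.model.model because of the PEFT wrapper. A minimal sketch of the same setup with Hugging Face transformers and peft; the checkpoint path and the token strings are placeholders (the diff's special-token literals appear stripped):

```python
import torch
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM, AutoTokenizer

llama_path = "path/to/llama-2-7b"          # placeholder; the real path comes from the model config

tokenizer = AutoTokenizer.from_pretrained(llama_path)
model = AutoModelForCausalLM.from_pretrained(llama_path, torch_dtype=torch.float16)

# Register extra delimiter tokens as single units, then grow the embedding matrix so their ids exist.
tokenizer.add_special_tokens({"additional_special_tokens": ["<Img>", "</Img>"]})   # illustrative names
model.resize_token_embeddings(len(tokenizer))

# LoRA on the attention projections, matching the defaults in the removed signature above.
lora_cfg = LoraConfig(r=8, lora_alpha=32, lora_dropout=0.05,
                      target_modules=["q_proj", "v_proj"], task_type="CAUSAL_LM")
model = get_peft_model(model, lora_cfg)
model.print_trainable_parameters()
```

Resizing the embeddings before the PEFT wrap sidesteps the base_model.model.model indirection that the diff handles explicitly.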
self.llama_tokenizer.add_tokens(["", ""], special_tokens=True) - self.bounding_box_label = self.llama_tokenizer.convert_tokens_to_ids("") - - # Langerhans cell histiocytosis - # print('meningioma',self.llama_tokenizer.convert_tokens_to_ids("meningioma")) - # print('self.bounding_box_label',self.bounding_box_label) - # self.location_label=self.llama_tokenizer.convert_ids_to_tokens([2287, 317, 351, 986, 284, 29892, 4423, 29901, 29871]) - # self.class_layer=nn.Linear(32*4096,8).to(self.device) - # self.boundinb_box_relu=nn.ReLU(inplace=True) - self.cross_attention=resolution_attention() - self.L1_loss_fn=nn.SmoothL1Loss() - # if self.model_2d_or_3d=='2d': - # self.spatial_pooling_layer_2d=SpatialPoolingProjector(in_dim=4096,out_dim=4096) - - # self.conv_layer=nn.Conv3d(32,16,kernel_size=3,stride=1,padding=1) - # self.conv3d_layer=nn.Sequential( - # nn.Conv3d(32,16,kernel_size=3,stride=1,padding=1), - # nn.Conv3d(16,8,kernel_size=3,stride=1,padding=1) - # ) - # self.conv3d_proj=nn.Linear(4096,4096) - # self.bounding_box_layer=nn.Linear(4*32*4096,4).to(self.device) - # decoder_embed_dim=512 - # decoder_num_heads=16 - # self.decoder_linear=nn.Linear(4096,512).to(self.device) - # self.decoder_blocks = nn.ModuleList([ - # Block(decoder_embed_dim, decoder_num_heads, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm) - # for i in range(1)]) - - # self.decoder_norm = nn.LayerNorm(decoder_embed_dim) - # self.decoder_pred = nn.Linear(decoder_embed_dim*128, 4, bias=True) - - for i in range(self.modality_number): - # print(freeze_vit) - # self.visual_encoder,self.ln_vision = self.init_vision_encoder( - # vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit,self.self_training,i - # ) - if self.model_2d_or_3d=='2d': - - # import json - - # with open("/grp01/saas_medml/yinong/ckpt/models--microsoft--BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/snapshots/f658b0b9eeca714b891d004b2d714c7ad3705170/open_clip_config.json", "r") as f: - # config = json.load(f) - # model_cfg = config["model_cfg"] - # preprocess_cfg = config["preprocess_cfg"] - # # 提取的是CLS特征 - # model_name='biomedclip_local' - # if (not model_name.startswith(HF_HUB_PREFIX) - # and model_name not in _MODEL_CONFIGS - # and config is not None): - # _MODEL_CONFIGS[model_name] = model_cfg - - # # tokenizer = get_tokenizer(model_name) - - # model, _, _ = open_clip.create_model_and_transforms(model_name=model_name,pretrained='/grp01/saas_medml/yinong/ckpt/models--microsoft--BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/snapshots/f658b0b9eeca714b891d004b2d714c7ad3705170/open_clip_pytorch_model.bin',**{f"image_{k}": v for k, v in preprocess_cfg.items()}) - - # 提取的是CLS特征 - model, _, _ = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224') - ''' - 2024.09.02 关于视觉encoder的设置 - # 如果想要提取last hidden layer特征 - # load 配置文件 - vision_cfg = open_clip.CLIPVisionCfg - # 修改output_tokens参数,这个参数设成True,forward后会返回两个值,一个是原来的[CLS] token feature,另一个是全部token的feature - vision_cfg.output_tokens = True - # 把这个参数load进model - model, _, _ = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', vision_cfg = vision_cfg) - 【这个好像不能work】 这个biomedclip封装得很深,缺乏相应的参数来控制输出层。网上搜了很多,很多人都碰到了这个问题 - https://github.com/mlfoundations/open_clip/issues/844 - - 建议直接用Clip试试,后面可以用自己pretrain的CLIP替换 - # 提取1024维度的feature - from transformers import CLIPProcessor, CLIPVisionModel - model_path = "models/clip-vit-large-patch14" - processor = 
CLIPProcessor.from_pretrained(model_path) - vision_model = CLIPVisionModel.from_pretrained(model_path).to(device) - vision_model.to(device) - images = Image.open(path) - inputs = processor(images=images, return_tensors="pt", padding=True).to(device) - with torch.no_grad(): - vision_outputs = vision_model(**inputs) - #最后一层隐藏层特征,(257,1024), 可以当作local feature - last_hidden_state = vision_outputs.last_hidden_state - #CLS token 特征,(1024),可以当作global feature - pooler_output = vision_outputs.pooler_output - ''' - ''' - 采用hook提取biomedclip last hidden layer特征 - # 注册钩子 - ''' - self.visual_encoder = model.visual.trunk - # model.load_state_dict(torch.load('/mnt/7T/yinong/ckpt/vit/test_125.pth')) - self.visual_encoder.requires_grad_(False) - - is_unfrozen = True # True False - if is_unfrozen: - # 解冻ViT的最后一层norm - self.visual_encoder.norm.weight.requires_grad_(True) - self.visual_encoder.norm.bias.requires_grad_(True) - # 解冻ViT的倒数n层Transformer - # N = -1 # 解冻最后N个transformer block - # for block in self.visual_encoder.blocks[N:].parameters(): # 解冻Vision Encoder的最后N个transformer block - # block.requires_grad = True - # 遍历模型的参数并打印出每个参数的 requires_grad 属性来检查哪些参数正在训练 - # for name, param in self.visual_encoder.named_parameters(): - # print(f"{name}: requires_grad={param.requires_grad}") - # self.ln_vision=nn.LayerNorm(self.visual_encoder.num_features) - # 最早单层MLP的版本 - # self.llama_proj_mlp = nn.Linear(768, 4096) - # 以前的llama_proj_mlps版本 - # self.llama_proj_mlps = nn.Sequential( - # nn.Linear(768, 768), - # nn.Linear(768, 4096), - # ) - # 修改 Yanzhaoshi 2024.09.08 参考Llava-Med的projector - self.llama_proj_mlps = nn.Sequential( - nn.Linear(768, 4096), - # 激活层 可以简单使用 nn.ReLU(), 速度会稍快一些, 或者 采用Llava-Med 的 nn.GELU(), 后者开销会大一些,但在NLP一般表现好于relu - nn.GELU(), - # nn.ReLU(), - nn.Linear(4096, 4096), + if self.has_qformer and self.model_2d_or_3d=='2d': + print('Loading Q-Former') + self.Qformer_list=[] + for i in range(1): + self.Qformer, self.query_tokens = self.init_Qformer( + num_query_token, 1408, freeze_qformer,self.self_training ) - - - elif self.model_2d_or_3d=='3d': - print('-----------------------------3d vit encoder-------------') - model = AutoModel.from_pretrained( - "GoodBaiBai88/M3D-CLIP", - # "GoodBaiBai88/M3D-LaMed-Llama-2-7B", - trust_remote_code=True) - self.visual_encoder=model.vision_encoder - # self.visual_encoder.requires_grad_(False) - - # self.ln_vision=nn.LayerNorm(self.visual_encoder.hidden_size) - load_vit=False - if load_vit==True: - state_dict_b = self.visual_encoder.state_dict() - keys_list = list(state_dict_b.keys()) - file_path="/home/ynwang/MiniGPT-4/minigpt4/ckpt/vit/checkpoint_59.pth" - print("load vit encoder") - ckpt = torch.load(file_path, map_location="cpu") - # 将model_a的权重拷贝到model_b - state_dict_a = ckpt['model'] - for para_idx,(name_a, param_a) in enumerate(state_dict_a.items()): - if para_idx>=144: - break - else: - cur_key=keys_list[para_idx] - if state_dict_b[cur_key].shape!=param_a.shape: - print("shape unmatch:{}".format(cur_key)) - else: - print("load parameter:{}".format(cur_key)) - state_dict_b[cur_key]=param_a - msg = self.visual_encoder.load_state_dict(state_dict_b, strict=False) - - - - self.visual_encoder_list.append(self.visual_encoder.to(self.device)) - - # self.ln_vision_list.append(self.ln_vision.to(self.device)) - self.max_txt_len = max_txt_len - self.max_context_len = max_context_len - self.end_sym = end_sym + self.load_from_pretrained(url_or_filename=q_former_model) # load q-former weights here + + img_f_dim = self.Qformer.config.hidden_size + 
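The long commented block above (originally in Chinese) notes that BiomedCLIP's open_clip wrapper does not expose per-token outputs cleanly and suggests a plain CLIP vision tower when the last hidden layer is needed; the repo instead hooks the trunk's final norm (see hook_for_vit_features below). A runnable English rendering of that suggestion, assuming the public openai/clip-vit-large-patch14 checkpoint as a stand-in for the local copy mentioned in the note:

```python
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPVisionModel

model_path = "openai/clip-vit-large-patch14"        # the note points at a local copy of this checkpoint
processor = CLIPProcessor.from_pretrained(model_path)
vision_model = CLIPVisionModel.from_pretrained(model_path).eval()

image = Image.new("RGB", (224, 224))                # stand-in for a real slice
inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    outputs = vision_model(**inputs)

local_features = outputs.last_hidden_state          # (1, 257, 1024): CLS + 256 patch tokens
global_feature = outputs.pooler_output              # (1, 1024): CLS feature, usable as a global feature
```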
self.Qformer_list.append(self.Qformer) + print('Loading Q-Former Done') + self.Qformer=self.Qformer_list[0] + elif self.model_2d_or_3d=='2d': + img_f_dim=self.visual_encoder_list[0].num_features + elif self.model_2d_or_3d=='3d': + self.projector=SpatialPoolingProjector() + img_f_dim = self.visual_encoder_list[0].hidden_size + print('Do not use Q-Former here.') + + # self.llama_proj = nn.Linear( + # img_f_dim, 4096 + # ) + # self.seg_layer=nn.Sequential( + # nn.Linear(4096, 4096), + # nn.ReLU(), + # nn.Linear(4096, 4) + # ) + # self.class_layer_4=nn.Linear(524288,16-1).to(self.device) + # self.class_layer_3=nn.Linear(32*4096*3,16-1).to(self.device) + # self.class_layer_2=nn.Linear(32*4096*2,16-1).to(self.device) + # self.class_layer_1=nn.Linear(32*4096*1,16-1).to(self.device) + # self.class_layer_list=[ + # self.class_layer_1,self.class_layer_2,self.class_layer_3,self.class_layer_4 + # ] + # self.boundinb_box_relu=nn.ReLU(inplace=True) - self.prompt_template = prompt_template - self.prompt_list = [] - self.grad_list=[] - # self.freeze_model() + # self.bounding_box_layer=nn.Linear(4*32*4096,4).to(self.device) - # self.visual_encoder_list[0].load_state_dict(torch.load("/home/haoran/Yanzhaoshi/MiniGPT-4/minigpt4/output/minigpt4_stage2_finetune/20240919232/checkpoint_59.pth")) - def freeze_model(self): - visual_num=0 - llama_num=0 - for n,p in self.named_parameters(): - # p.requires_grad=False - if 'visual_encoder' in n or 'llama_proj_mlps' in n or 'ln_vision' in n: - # print(n) - p.requires_grad=True - visual_num+=p.numel() - else: - llama_num+=p.numel() - p.requires_grad=False - self.grad_list.append(p) - - def vit_to_cpu(self): - # self.ln_vision.to("cpu") - # self.ln_vision.float() - self.visual_encoder.to("cpu") - self.visual_encoder.float() - - # 需要修改点1 tokenizer.bos_token组要在最前面,是llama3的开始符 - # 点2 在instruction里加入模态的名称和分隔符 + if self.self_training: + # for name,param in self.llama_proj.named_parameters(): + # param.requires_grad = False + # for class_layer in self.class_layer_list: + # for name,param in class_layer.named_parameters(): + # param.requires_grad = False + for name,param in self.bounding_box_layer.named_parameters(): + param.requires_grad = False + print('freeze head params') + if prompt_path: + with open(prompt_path, 'r') as f: + raw_prompts = f.read().splitlines() + filted_prompts = [raw_prompt for raw_prompt in raw_prompts if "" in raw_prompt] + self.prompt_list = [prompt_template.format(p) for p in filted_prompts] + print('Load {} training prompts'.format(len(self.prompt_list))) + print('Prompt Example \n{}'.format(random.choice(self.prompt_list))) + else: + self.prompt_list = [] + + @classmethod + def init_Qformer(cls, num_query_token, vision_width, freeze,self_training=False): + encoder_config = BertConfig.from_pretrained("bert-base-uncased") + encoder_config.encoder_width = vision_width + # insert cross-attention layer every other block + encoder_config.add_cross_attention = True + encoder_config.cross_attention_freq = 2 + encoder_config.query_length = num_query_token + Qformer = BertLMHeadModel(config=encoder_config) + query_tokens = nn.Parameter( + torch.zeros(1, num_query_token, encoder_config.hidden_size) + ) + query_tokens.data.normal_(mean=0.0, std=encoder_config.initializer_range) + + Qformer.cls = None + Qformer.bert.embeddings.word_embeddings = None + Qformer.bert.embeddings.position_embeddings = None + for layer in Qformer.bert.encoder.layer: + layer.output = None + layer.intermediate = None + + if self_training: + for name, param in Qformer.named_parameters(): 
+ param.requires_grad = False + Qformer = Qformer.eval() + Qformer.train = disabled_train + query_tokens.requires_grad = False + logging.info("freeze Qformer") + + return Qformer, query_tokens + + def hook_for_vit_features(self, image): + # step1 创建钩子的容器 + activations = {} + # step2 定义钩子函数 + def hook_fn(module, input, output): + activations[module] = output + # step3 注册钩子 + hook_handle_norm = self.visual_encoder.norm.register_forward_hook(hook_fn) # block 出来的第一个norm + # step4 前馈过程,得到768特征 + latent = self.visual_encoder(image) # [128, 3, 224, 224] -> [128, 768] + # step5 钩取 last hidden layer output + image_token_embeds = activations[self.visual_encoder.norm][:, 1:, :] # 和Llava-Med一样不使用最前面的CLS,取后面的image patch tokens torch.Size([224, 196, 768]) + # step6 删除钩子和容器 + hook_handle_norm.remove() + activations = {} + return image_token_embeds - def get_context_emb(self, prompt, img_list): - device = img_list[0].device - prompt_segs = prompt.split('') # 分成两部分 - - - - seg_tokens = [ - self.llama_tokenizer( - seg, return_tensors="pt", add_special_tokens=i==0).to(device).input_ids # only add bos to the first seg - for i, seg in enumerate(prompt_segs) - ] - seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens] - - - # for idx, seg in enumerate(p_segs[:-1]): # 遍历全部的按照''分隔后的部分,除了最后面的一项 - # # 提取当前子部分的token id 并编码 - # p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) - # p_embed = self.embed_tokens(p_tokens.input_ids) - # # 累加器加上当前seg中的token数量 - # count_current_len += p_embed.size(1) - # # volume visual token 开始位 - # self.start_visual_token_idx.append(count_current_len) - # # 在这个子部分文本后,插入当前volume的视觉特征 - # interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * vis_chunk_size:(idx + 1) * vis_chunk_size]], dim=1)) - # # 累加器加上当前vis_chunk中的token数量 - # count_current_len += vis_chunk_size - # # volume visual token 结束位置 - # self.end_visual_token_idx.append(count_current_len) + def encode_img(self, image, resolution_type, modality=['t1n']): + # image = torch.stack(image, dim=1) - count_current_len=1 - old_version=1 - if old_version==0: - mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] - mixed_embs = torch.cat(mixed_embs, dim=1) - - # 2024.09.07 Yanzhaoshi修改 加入llama的起始符到句首 self.llama_tokenizer.bos_token - #llama:第一个字符是bos开始符 - bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids) - # mixed_embs = torch.cat([bos_embed, mixed_embs], dim=1) - - # self.start_visual_token_idx=[2] - # self.end_visual_token_idx=[3] - # self.llama_model.set_index(self.start_visual_token_idx,self.end_visual_token_idx) - eos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids) - - mixed_embs = torch.cat([bos_embed, mixed_embs,eos_embed], dim=1) - else: - image_list=[] - cur_att=1 - if cur_att==0: - for index in range(1,5): - image_list.append(img_list[0][:,32*(index-1):32*index]) - mixed_embs = [emb for pair in zip(seg_embs[:-1], image_list) for emb in pair] + [seg_embs[-1]] - - ####################### - # mixed_embs=[] - # for pair in zip(seg_embs[:-1], image_list): - # for emb in pair: - # mixed_embs.append(emb) - - mixed_embs=mixed_embs+[seg_embs[-1]] - - - mixed_embs = torch.cat(mixed_embs, dim=1) - bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", 
add_special_tokens=False).to(mixed_embs.device).input_ids) - mixed_embs = torch.cat([bos_embed, mixed_embs], dim=1) - # eos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids) - - # mixed_embs = torch.cat([bos_embed, mixed_embs,eos_embed], dim=1) - - elif cur_att==1: - image_list=[] - - # 记录每个volume token的结束位置 example [34, 67, 100, 133] - - for index in range(1,5): - image_list.append(img_list[0][:,32*(index-1):32*index]) - # mixed_embs = [emb for pair in zip(seg_embs[:-1], image_list) for emb in pair] + [seg_embs[-1]] - - count_current_len = 1 # 计算当前prompt长度,取1是因为有开始符,开始符在循环外边会追加。用于找是视觉token的位置 - self.start_visual_token_idx = [] # 记录每个volume token的开始位置 example [2, 35, 68, 101] - self.end_visual_token_idx = [] - mixed_embs=[] - for pair in zip(seg_embs[:-1], image_list): - count_current_len+=1 - self.start_visual_token_idx.append(count_current_len) - for emb in pair: - - # volume visual token 结束位置 - mixed_embs.append(emb) - count_current_len += 32 - self.end_visual_token_idx.append(count_current_len) - self.llama_model.set_index(self.start_visual_token_idx,self.end_visual_token_idx) - mixed_embs=mixed_embs+[seg_embs[-1]] - - mixed_embs = torch.cat(mixed_embs, dim=1) - bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids) - eos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids) - # mixed_embs = torch.cat([bos_embed, mixed_embs], dim=1) - - mixed_embs = torch.cat([bos_embed, mixed_embs,eos_embed], dim=1) - return mixed_embs - - def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None): # 拼接prompts - if prompts is None or len(prompts) == 0: - # prompts is not provided, just return the original image embedding - return img_embeds, atts_img - elif img_embeds is None: - # prompt is provided but there is no image embedding. 
return the prompt embedding in right padding - self.llama_tokenizer.padding_side = "right" - prompt_tokens = self.llama_tokenizer( - prompts, - return_tensors="pt", - padding="longest", - add_special_tokens=False - ).to(self.device) - prompt_embeds = self.embed_tokens(prompt_tokens.input_ids) - atts_prompt = prompt_tokens.attention_mask - return prompt_embeds, atts_prompt - else: # 走这里 - # return the multi-modal embedding in right padding - emb_lists = [] - if isinstance(prompts, str): - prompts = [prompts] * len(img_embeds) - self.start_visual_token_idx = None - self.end_visual_token_idx = None - cond_ids_list=[] - for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)): - pn = each_img_embed.shape[-2] # 图像总个数 128 - # pn = each_img_embed.shape[-1] - - if lengths is not None: - each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1]) - each_img_embed = each_img_embed[:lengths[idx] * pn] - p_segs = each_prompt.split('') # 把token的左右内容分开,插入图像特征 - interleave_emb = [] - interleave_ids=[] - '''2024.09.08 Yanzhaoshi修改: 以32为一组, 拆分视觉特征 {instruction} ''' - count_current_len = 1 # 计算当前prompt长度,取1是因为有开始符,开始符会在后面追加。该变量用于找是视觉token的位置 - self.start_visual_token_idx = [] # 记录每个volume token的开始位置 example [2, 35, 68, 101] - self.end_visual_token_idx = [] # 记录每个volume token的结束位置 example [34, 67, 100, 133] - if '' in each_prompt: # 如果检测到了分隔符,就对视觉特征按照volume进行分块插入,每块都会用隔开 - # 2024.09.09 计算每个volume起始和终点位置,确保后面LLM中casual mask的位置正确 - # instruction示例 后面加入了对应模态的名称在前面 - vis_chunk_size = 32 # 一个modality的图像个数,作为一个块 - for idx, seg in enumerate(p_segs[:-1]): # 遍历全部的按照''分隔后的部分,除了最后面的一项 - # 提取当前子部分的token id 并编码 - p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) - #10.13 update ids - interleave_ids += p_tokens.input_ids.squeeze(0).tolist() - - p_embed = self.embed_tokens(p_tokens.input_ids) - # 累加器加上当前seg中的token数量 - count_current_len += p_embed.size(1) - # volume visual token 开始位 - self.start_visual_token_idx.append(count_current_len) - # 在这个子部分文本后,插入当前volume的视觉特征 - interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * vis_chunk_size:(idx + 1) * vis_chunk_size]], dim=1)) - - interleave_ids += self.llama_tokenizer('', add_special_tokens=False).input_ids * (int(pn/4)) - - # 累加器加上当前vis_chunk中的token数量 - count_current_len += vis_chunk_size - # volume visual token 结束位置 - self.end_visual_token_idx.append(count_current_len) - - - else: # 如果没检测到分隔符,就还是用以前的代码,把视觉特征全部一起插入 2D 数据会走这里 - for idx, seg in enumerate(p_segs[:-1]): # 遍历全部的按照''分隔后的部分,除了最后面的一项 - # 提取当前子部分的token id 并编码 - p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) - p_embed = self.embed_tokens(p_tokens.input_ids) - # 累加器加上当前seg中的token数量 - count_current_len += p_embed.size(1) - # volume visual token 开始位 - self.start_visual_token_idx.append(count_current_len) - # 以前这种直接插入全部图像的方法 - interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1)) - # 累加器加上当前vis_chunk中的token数量 - count_current_len += pn - # volume visual token 结束位置 - self.end_visual_token_idx.append(count_current_len) - - wrapped_emb = torch.cat(interleave_emb, dim=1) # 将前面每块的特征合并起来 - - # 2024.09.07 Yanzhaoshi修改 加入llama的起始符到句首 self.llama_tokenizer.bos_token - # 2024.09.23 Yanzhaoshi修改 取消该部分代码,因为后面有了bos,避免加重了 - # bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(img_embeds.device).input_ids) - # wrapped_emb = torch.cat([bos_embed, 
wrapped_emb], dim=1) - - p_tokens = self.llama_tokenizer( - p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device) # 拼接上最后一块instruction的特征 - - interleave_ids += p_tokens.input_ids.squeeze(0).tolist() - cond_ids_list.append(interleave_ids) - cond_ids_list = torch.tensor(cond_ids_list).to(img_embeds.device) - - #10.13 update - - p_embed = self.embed_tokens(p_tokens.input_ids) - wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1) # 在[1, 131, 4096]后面加上instruction文本[1, 32, 4096] - - # 2024.09.11 Yanzhaoshi 修改 加入eos终止符 self.llama_tokenizer.eos_token - # 2024.10.13 之后训练时去掉这里 - # eos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(img_embeds.device).input_ids) - # wrapped_emb = torch.cat([wrapped_emb, eos_embed], dim=1) - # cond_ids_list = torch.cat([cond_ids_list, torch.tensor(self.llama_tokenizer.eos_token_id).cuda().unsqueeze(0).unsqueeze(0)], dim=1) - - emb_lists.append(wrapped_emb) # 合并后 [1, 163, 4096] - - emb_lens = [emb.shape[1] for emb in emb_lists] - pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) # 计算padding token的embedding [4096] - - max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len # 最大长度max_length 163 - wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone() # [1, 163, 4096] - wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device) # [1, 163] - - for i, emb in enumerate(emb_lists): - length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len # 163 - wrapped_embs[i, :length] = emb[:, :length] # 两者是相同的 - wrapped_atts[i, :length] = 1 - return wrapped_embs, wrapped_atts,cond_ids_list - - def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): - """ - Concatenate the batched input embedding and batched output embedding together. - Both the input and the output embedding should be right padded. 
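prompt_wrap above splits the instruction at the image placeholder, inserts one 32-token visual chunk per modality, and records each chunk's start and end position so the LLM can later treat every volume as a block in its causal mask. A stripped-down sketch of that interleaving and index bookkeeping; the placeholder string, the fake embed_text helper, and the shapes are illustrative (the diff's placeholder literal appears stripped):

```python
import torch

def embed_text(segment):                    # stand-in for tokenizer + embed_tokens
    return torch.randn(1, max(len(segment.split()), 1), 4096)

prompt = "t1n: <Vol> t1c: <Vol> Describe the tumour."   # illustrative; split on the image placeholder
segments = prompt.split("<Vol>")
img_embeds = torch.randn(1, 64, 4096)       # 2 volumes x 32 projected slice tokens
chunk = 32

pieces, start_idx, end_idx = [], [], []
cur = 1                                     # offset 1 leaves room for the BOS embedding added later
for i, seg in enumerate(segments[:-1]):
    t = embed_text(seg)
    cur += t.size(1)
    start_idx.append(cur)                   # first position of this volume's visual tokens
    pieces.append(torch.cat([t, img_embeds[:, i * chunk:(i + 1) * chunk]], dim=1))
    cur += chunk
    end_idx.append(cur)                     # position just past this volume's visual tokens
pieces.append(embed_text(segments[-1]))
wrapped = torch.cat(pieces, dim=1)          # text and volume tokens interleaved, right-padded later
```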
- """ - input_lens = [] - cat_embs = [] - cat_atts = [] - for i in range(input_embs.size(0)): # input_embs [1, 163, 4096] output_embs [1, 30, 4096] - input_len = input_atts[i].sum() - input_lens.append(input_len) - cat_embs.append( - torch.cat([ - input_embs[i][:input_len], - output_embs[i], - input_embs[i][input_len:] - ]) - ) # cat_embs [193, 4096] - cat_atts.append( - torch.cat([ - input_atts[i][:input_len], - output_atts[i], - input_atts[i][input_len:] - ]) - ) # cat_atts [193] - cat_embs = torch.stack(cat_embs) # [1, 193, 4096] - cat_atts = torch.stack(cat_atts) # [1, 193] - return cat_embs, cat_atts, input_lens - - def tokenize_conversation(self, conv_q, conv_a): - """concatenate conversation and make sure the model is only trained to regress the answer""" - - to_regress_token_ids_list = [] - targets_list = [] - - batch_size = len(conv_q) - for batch_idx in range(batch_size): - questions, answers = conv_q[batch_idx], conv_a[batch_idx] - questions = [self.llama_tokenizer(self.llama_tokenizer.bos_token + q, - return_tensors="pt", - add_special_tokens=False).to(self.device) for q in questions[1:]] # the first question is handled in the prompt wrap function, skip it - answers = [self.llama_tokenizer(a + self.end_sym, - return_tensors="pt", - add_special_tokens=False).to(self.device) for a in answers] - cur_id = [] - cur_target = [] - for i in range(len(questions)): - cur_id.append(answers[i].input_ids) - cur_target.append(answers[i].input_ids) - cur_id.append(questions[i].input_ids) - cur_target.append(torch.ones_like(questions[i].input_ids) * -100) - - cur_id.append(answers[-1].input_ids) - cur_target.append(answers[-1].input_ids) - - cur_id = torch.cat(cur_id, dim=1) - cur_target = torch.cat(cur_target, dim=1) - to_regress_token_ids_list.append(cur_id) - targets_list.append(cur_target) - - max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len) - to_regress_token_ids = torch.ones([batch_size, max_len], - dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id - targets = torch.ones([batch_size, max_len], - dtype=cur_id.dtype, device=self.device) * -100 - for batch_idx in range(batch_size): - cur_len = to_regress_token_ids_list[batch_idx].shape[1] - to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len] - targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len] - - to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int) - - return to_regress_token_ids, to_regress_token_attn, targets - - def preparing_embedding(self, samples): - ### prepare input tokens - if 'several_modalities_image' in samples: - # print(samples['several_modalities_image']) - - images=samples['several_modalities_image']['images'] - # print(images.shape) - # print(len(images)) - modalities=samples['several_modalities_image']['modalities'] - # print(modalities) - device=images[0].device - temp=[] - for index,image in enumerate(images): - # print(image.shape) - # img_embeds, img_atts,self_loss = self.encode_img(image,modalities[index]) - # img_embeds, img_atts,self_loss=self.encode_img(image,modalities[index]) - img_embeds, img_atts,self_loss=self.encode_img(image) - - # print(img_embeds.shape) - temp.append(img_embeds) - img_embeds=torch.cat(temp, dim=1) - # print() - elif 'image' in samples: - if type(samples["image"])==list: - images = torch.stack(samples["image"]).to(self.device) - # print(images.shape) - else: - images=samples["image"].to(self.device) - # print(samples["image"]) - 
# images=samples["image"].to(self.device) - img_embeds, img_atts, self_loss = self.encode_img(images, "LR_Encoding") #LR 图像的维度是[b, 4(模态数), 32(切片数), 3(初始通道数), 224(长), 224(宽)] torch.Size([1, 7, 32, 3, 224, 224]) - # img_embeds [128, 1, 768] img_atts [128, 1] (全1矩阵) - if "HR_resolution" in samples.keys() and samples['HR_resolution'] ==True: #HR (0.1,0.07,nan,nan) - HR_img_embeds, HR_img_atts, HR_self_loss = self.encode_img(samples["HR_image_list"].to(self.device), "HR_Encoding") #HR [224*196*4, 1, 768] - # - # print(HR_img_embeds.shape,img_embeds.shape) - img_embeds=self.cross_attention(img_embeds.reshape(1,-1,768), HR_img_embeds.reshape(1,-1,768)) # 1, 128, 768 - # print('----------------') - # print(img_embeds.shape) - # 修改1 在此处加上projector,映射到LLM维度4096 - img_embeds=self.llama_proj_mlps(img_embeds) # 1, 128, 4096 - - - device=images.device - else: - img_embeds = img_atts = None - # print(img_embeds.shape) - if 'conv_q' in samples: - # handeling conversation datasets - conv_q, conv_a = samples['conv_q'], samples['conv_a'] - - connect_sym = samples['connect_sym'][0] - conv_q = [q.split(connect_sym)for q in conv_q] - conv_a = [a.split(connect_sym) for a in conv_a] - - conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q] - - cond_embeds, cond_atts,_ = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q]) - regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a) - - else: - if "instruction_input" in samples: - instruction = samples["instruction_input"] # [" [caption] There are several MRI sequence from 1 patient. Please make a detailed diagnosis for this patient including the tumor's detailed type. "] - elif self.prompt_list: - instruction = random.choice(self.prompt_list) - else: - instruction = None - - if hasattr(self, 'chat_template') and self.chat_template: - instruction = [self.prompt_template.format(instruct) for instruct in instruction] - - if 'length' in samples: - # the input is a image train (like videos) - bsz, pn, hs = img_embeds.shape - # print(img_embeds.shape) - img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) - - cond_embeds, cond_atts,_ = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length']) - else: - # return self_loss - # print(img_embeds.shape)#(a,32,4096) - # img_embeds=self.cross_attention(img_embeds) - # print('img_embeds_0',img_embeds.shape) - - # img_embeds=img_embeds.reshape(1,-1,*img_embeds.shape[-1:])#32,32, - - # print('img_embeds_1',img_embeds.shape) - if img_embeds.size(0)>8 and self.has_qformer==True: - img_embeds=img_embeds.reshape(img_embeds.size(0),img_embeds.size(1),64,64) - # img_embeds=self.spatial_pooling_layer_2d(img_embeds) - # print('img_embeds',img_embeds.shape) - img_embeds=img_embeds.permute(1,0,2,3) - img_embeds=self.conv3d_layer(img_embeds) - img_embeds=img_embeds.reshape(img_embeds.size(0),img_embeds.size(1),4096) - img_embeds=self.conv3d_proj(img_embeds).reshape(1,-1,4096) - else: # 走这里 - img_embeds=img_embeds.reshape(1,-1,4096) #1次输入很多张图-1 instruction - - # print('img_embeds',img_embeds.shape) - - # print('after downsampling',img_embeds.shape) - # print(img_embeds.shape)#(1,a*32,4096) - # 将multi-modal prompt进行拼接 - cond_embeds, cond_atts,_ = self.prompt_wrap(img_embeds, img_atts, instruction) # cond_embeds [1, 163, 4096] cond_atts [1, 163] - - ### prepare target tokens - self.llama_tokenizer.padding_side = "right" - text = [t + self.end_sym for t in samples["answer"]] # target (self.end_sym = ###) 'This patient was diagnosed with Brain 
Metastases Tumor, and further diagnosed as Brain Metastases Tumor###' - # print(f"text {text}") - regress_tokens = self.llama_tokenizer( - text, - return_tensors="pt", - padding="longest", - truncation=True, - max_length=self.max_txt_len, - add_special_tokens=False - ).to(self.device) # regress is target (对gt report进行tokenize) torch.Size([1, 30]) - # print(regress_tokens.shape) - regress_token_ids = regress_tokens.input_ids # [1, 30] - # print("regress_token_ids:",regress_token_ids) - regress_atts = regress_tokens.attention_mask # [1, 30] - part_targets = regress_token_ids.masked_fill( - regress_token_ids == self.llama_tokenizer.pad_token_id, -100 - ) # 不关注填充的token [1, 30] - # print(part_targets) - # print(part_targets.shape) - if 'class_target' in samples: - class_target=samples['class_target'] - class_target=class_target.to(device) - else: - class_target=None - regress_embeds = self.embed_tokens(regress_token_ids) # gt report embedding [1, 30, 4096] - - - return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets, img_embeds, class_target, self_loss - - def forward(self, samples, reduction='mean'): - # prepare the embedding to condition and the embedding to regress + device = image.device + encoder_index=self.modalities[modality[0]] - # not_load=True - if self.self_training == False: - cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets,input_embs,class_target,self_loss = \ - self.preparing_embedding(samples) # - else: - self.visual_encoder=self.visual_encoder_list[0] - image=samples['image'].to(self.device) + if self.model_2d_or_3d=='2d': # 按照2d单张图片的方式处理3d影像 if len(image.shape) > 4: - image = image.reshape(-1, *image.shape[-3:]) - with self.maybe_autocast(): - loss,_,_,latent=self.visual_encoder(image) - # self_loss=self.encode_img(samples) - return {"loss": loss} - # print(samples['bounding_box']) - # if 'bounding_box' in samples and samples['bounding_box']!=['None']: - # label=samples['bounding_box'] - # # bounding_box_loss_fn=iou_loss() - # # x=input_embs - # # x=self.decoder_linear(x) - # # for blk in self.decoder_blocks: - # # x = blk(x) - # # x = self.decoder_norm(x) - - # # # predictor projection - # # x = self.decoder_pred(x.view(1,-1)) - # # # print(input_embs.view(1,-1).shape) - # bounding_box=self.boundinb_box_relu(self.bounding_box_layer(input_embs.view(1,-1))) - # loss_fn=nn.SmoothL1Loss() - # # print('bounding box',bounding_box[0]) - # # print('label',label[0]) - # # print('bounding box',bounding_box) - # # print('label',label) - # iou_loss=calculate_iou(label[0],bounding_box[0]) - # giou_loss=loss_fn(label,bounding_box) - # bounding_box_loss=giou_loss - - # else: - # iou_loss=1 - # bounding_box_loss=None - # print(cond_embeds.shape) - # concat the embedding to condition and the embedding to regress - - # print('class target is ',class_target) - if class_target!=None: - img_num=input_embs.shape[1]/32 - # class_predict=self.class_layer_list[int(img_num)-1](input_embs.view(input_embs.size(0), -1)) - classification_loss_fn=torch.nn.CrossEntropyLoss() - # print('predict shape is',class_predict.shape) - # print('target shape is ',class_target.shape) - class_loss=classification_loss_fn(class_predict,class_target) - else: - class_loss=None - inputs_embeds, attention_mask, input_lens = \ - self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts) # 把instruction和target concat - - # get bos token embedding - bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id - bos_embeds = self.embed_tokens(bos) # [1, 
1, 4096] - bos_atts = cond_atts[:, :1] # [1, 1] 值为1 - - # # add bos token at the begining - # print(bos_embeds.shape,inputs_embeds.shape) - inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) # [1, 173, 4096] - attention_mask = torch.cat([bos_atts, attention_mask], dim=1) # [1, 173] 全1 - - # ensemble the final targets - targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], - dtype=torch.long).to(self.device).fill_(-100) # - - for i, target in enumerate(part_targets): # 将前面的token id变成-100,后面gt的index还是对应的 - targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos - # print(targets) - # print(inputs_embeds.shape) - # mask=samples['mask'] - # if bounding_box_loss==Tr: - # print('self_loss',self_loss) - # print(targets.shape) - - if self_loss==None: - with self.maybe_autocast(): # 将multimodal prompt输入至LLM - outputs = self.llama_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, - start_visual_idx = self.start_visual_token_idx, - end_visual_idx = self.end_visual_token_idx, - return_dict=True, - labels=targets, - reduction=reduction - ) - hidden_states=outputs.hidden_states[-1] # outputs.hidden_states [1, 173, 4096] hidden_states [173, 4096] - # class_predict=self.class_layer(hidden_states[:,:32,:].view(1,-1)) - # print(outputs.keys()) - # print(hidden_states.shape) - # if self_loss: - # loss=self_loss - if "bounding_box" in samples: - gt_bounding_box=samples["bounding_box"] - # print(self.bounding_box_label) - # print(targets.shape) - # print(self.location_label) - output_token_index = (targets == self.bounding_box_label).nonzero() - # print(targets) - if len(output_token_index)>0: - location_state=hidden_states[output_token_index[0][-1]] - bounding_box=self.seg_layer(location_state.view(1, -1)) - # print(gt_bounding_box[0]) - # print('predict',bounding_box) - # print('gt',gt_bounding_box) - # print(gt_bounding_box) - - L1_loss=self.L1_loss_fn(bounding_box,gt_bounding_box) - # print(bounding_box,gt_bounding_box) - giou_loss=bbox_giou_loss(bounding_box,gt_bounding_box) - # print(1-giou_loss) - - if giou_loss>0: - # print(L1_loss) - loss=outputs.loss+0.2*L1_loss+1.2*giou_loss - else: - loss=0.8*outputs.loss+0.2*L1_loss - else: - loss=outputs.loss - # print(giou_loss) - # print(output_token_index) - # if len(output_token_index): - # addon_index = torch.ones_like(output_token_index)*(-1) - # addon_index[:, 0] = 0 - # output_token_index += addon_index - # print(output_token_index) - # class_predict - # loss = class_loss - # elif bounding_box_loss: - # loss=bounding_box_loss - - else: - loss=outputs.loss - # print("loss is ",loss) - - # loss=outputs.loss - # if self_loss: - # # print('sssssssssssss') - # loss=loss+0.25*self_loss - - if torch.isnan(loss): - print(samples['image_name']) - return {"loss": loss} - - def embed_tokens(self, token_ids): - if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model - embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) - else: - embeds = self.llama_model.base_model.embed_tokens(token_ids) - return embeds - - @torch.no_grad() - def multi_generate( - self, - samples, - texts, - num_beams=1, - max_new_tokens=160, - min_length=1, - top_p=0.9, - repetition_penalty=1, - length_penalty=1, - temperature=1, - do_sample=False, - stop_words_ids=[2], - ): - ''' - function for generate test use - ''' - - stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( - stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) - # images = torch.stack(images, 
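In the removed forward above, when a bounding-box token is present the hidden state at that position is regressed to four coordinates and the language-modelling loss is mixed with SmoothL1 and a GIoU-style term (weights 1.0/0.2/1.2, falling back to 0.8/0.2 when the box term is non-positive). A self-contained sketch of that combination; iou_loss here is a simple 1 - IoU stand-in for bbox_giou_loss, and the shapes, weights initialization, and token position are illustrative:

```python
import torch
import torch.nn as nn

def iou_loss(pred, target, eps=1e-10):
    """1 - IoU for [ymax, xmax, ymin, xmin] boxes; stand-in for the repo's bbox_giou_loss."""
    px1, px0 = torch.max(pred[:, 1], pred[:, 3]), torch.min(pred[:, 1], pred[:, 3])
    py1, py0 = torch.max(pred[:, 0], pred[:, 2]), torch.min(pred[:, 0], pred[:, 2])
    tx1, tx0 = torch.max(target[:, 1], target[:, 3]), torch.min(target[:, 1], target[:, 3])
    ty1, ty0 = torch.max(target[:, 0], target[:, 2]), torch.min(target[:, 0], target[:, 2])
    inter = (torch.min(px1, tx1) - torch.max(px0, tx0)).clamp(min=0) * \
            (torch.min(py1, ty1) - torch.max(py0, ty0)).clamp(min=0)
    union = (px1 - px0) * (py1 - py0) + (tx1 - tx0) * (ty1 - ty0) - inter
    return (1 - inter / (union + eps)).mean()

hidden_states = torch.randn(173, 4096)                      # last-layer states for one sequence
seg_layer = nn.Sequential(nn.Linear(4096, 4096), nn.ReLU(), nn.Linear(4096, 4))
box_pos = 150                                               # index of the box token found in `targets`
pred_box = seg_layer(hidden_states[box_pos].view(1, -1))    # (1, 4) regressed box
gt_box = torch.tensor([[0.9, 0.8, 0.3, 0.2]])

lm_loss = torch.tensor(1.7)                                 # stand-in for outputs.loss
l1 = nn.SmoothL1Loss()(pred_box, gt_box)
box_loss = iou_loss(pred_box, gt_box)
loss = lm_loss + 0.2 * l1 + 1.2 * box_loss if box_loss > 0 else 0.8 * lm_loss + 0.2 * l1
```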
dim=1) - # img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device)) - - if 'several_modalities_image' in samples: - images=samples['several_modalities_image']['images'] - # print(len(images)) - modalities=samples['several_modalities_image']['modalities'] - # print(modalities) - # device=images[0].device - temp=[] - for index,image in enumerate(images): - # print(image.shape) - # img_embeds, img_atts,self_loss = self.encode_img(image,modalities[index]) - img_embeds, img_atts,self_loss = self.encode_img(image.to(self.device).half(),modalities[index]) - # print(img_embeds.shape) - temp.append(img_embeds) - img_embeds=torch.cat(temp, dim=1) - - image_lists = [[image_emb[None]] for image_emb in img_embeds] - # print(len(image_lists),len(texts)) - # for i in image_lists: - # print(i[0].shape) - batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] - - batch_size = len(batch_embs) - max_len = max([emb.shape[1] for emb in batch_embs]) - emb_dim = batch_embs[0].shape[2] - dtype = batch_embs[0].dtype - device = batch_embs[0].device - - embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) - attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) - for i, emb in enumerate(batch_embs): - emb_len = emb.shape[1] - embs[i, -emb_len:] = emb[0] - attn_mask[i, -emb_len:] = 1 - with self.maybe_autocast(): - outputs = self.llama_model.generate( - inputs_embeds=embs, - attention_mask=attn_mask, - max_new_tokens=max_new_tokens, - num_beams=num_beams, - length_penalty=length_penalty, - temperature=temperature, - do_sample=do_sample, - min_length=min_length, - top_p=top_p, - repetition_penalty=repetition_penalty, - stopping_criteria=stopping_criteria, - ) - - answers = [] - for output_token in outputs: - if output_token[0] == 0: - output_token = output_token[1:] - output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) - output_texts = output_texts.split('')[0] # remove the stop sign - output_texts = output_texts.replace("", "") - output_texts = output_texts.split(r'[/INST]')[-1].strip() - answers.append(output_texts) - - return answers - - @torch.no_grad() - def generate( - self, - images, - texts, - num_beams=1, - max_new_tokens=300, - min_length=1, - top_p=0.9, - repetition_penalty=1, - length_penalty=1, - temperature=1, # 可以选择0.1, 0.2, 0.4, 0.6, 0.8, 1 - do_sample=False, - stop_words_ids=[2], - ): - ''' - function for generate test use - ''' - - stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( - stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) - # images = torch.stack(images, dim=1) - img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device)) - - image_lists = [[image_emb[None]] for image_emb in img_embeds] - # print(len(image_lists),len(texts)) - # for i in image_lists: - # print(i[0].shape) - batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] - - batch_size = len(batch_embs) - max_len = max([emb.shape[1] for emb in batch_embs]) - emb_dim = batch_embs[0].shape[2] - dtype = batch_embs[0].dtype - device = batch_embs[0].device - - embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) - attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) - for i, emb in enumerate(batch_embs): - emb_len = emb.shape[1] - embs[i, -emb_len:] = emb[0] - attn_mask[i, -emb_len:] = 1 - - with self.maybe_autocast(): - outputs = self.llama_model.generate( - 
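multi_generate and generate above left-pad the variable-length prompt embeddings into one batch so every prompt ends at the final position before generation starts. A small sketch of that padding pattern under assumed shapes:

```python
import torch

batch_embs = [torch.randn(1, n, 4096) for n in (151, 163, 140)]   # per-sample prompt embeddings

max_len = max(e.shape[1] for e in batch_embs)
embs = torch.zeros(len(batch_embs), max_len, 4096)
attn_mask = torch.zeros(len(batch_embs), max_len, dtype=torch.int)
for i, emb in enumerate(batch_embs):
    n = emb.shape[1]
    embs[i, -n:] = emb[0]          # real tokens pushed to the right
    attn_mask[i, -n:] = 1          # zeros on the left mark padding

# llama_model.generate(inputs_embeds=embs, attention_mask=attn_mask, ...) would follow here.
```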
inputs_embeds=embs, - attention_mask=attn_mask, - max_new_tokens=max_new_tokens, - num_beams=num_beams, - length_penalty=length_penalty, - temperature=temperature, - do_sample=do_sample, - min_length=min_length, - top_p=top_p, - repetition_penalty=repetition_penalty, - stopping_criteria=stopping_criteria, - ) - - answers = [] - for output_token in outputs: - if output_token[0] == 0: - output_token = output_token[1:] - output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) - output_texts = output_texts.split('')[0] # remove the stop sign - output_texts = output_texts.replace("", "") - output_texts = output_texts.split(r'[/INST]')[-1].strip() - answers.append(output_texts) - - return answers - - # def single_predict_class(self,samples,device): - # images=samples['image'] - - # img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device)) - # class_predict=self.class_layer_list[0](img_embeds.view(1, -1).float()) - # return class_predict - - def predict_class(self,samples,device): - # img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device)) - # print(type(img_embeds)) - # print(img_embeds.shape) - images=samples['several_modalities_image']['images'] - # print(len(images)) - modalities=samples['several_modalities_image']['modalities'] - # print(modalities) - # device=images[0].device - temp=[] - num=len(images) - for index,image in enumerate(images): + image = image.reshape(-1, *image.shape[-3:]) # [128, 3, 224, 224] 合并前面几个维度,batch、modality和slices + elif self.model_2d_or_3d=='3d': + if len(image.shape) > 4: # print(image.shape) - # img_embeds, img_atts,self_loss = self.encode_img(image,modalities[index]) - img_embeds, img_atts,self_loss=self.encode_img(image.to(device).half(),modalities[index]) - # print(img_embeds.shape) - temp.append(img_embeds) - img_embeds=torch.cat(temp, dim=1) - - # class_predict=self.class_layer_list[num-1](img_embeds.view(1, -1).float()) - # if class_target!=None: - # classification_loss_fn=torch.nn.CrossEntropyLoss() - # class_loss=classification_loss_fn(class_predict,class_target) - return class_predict - - - @torch.no_grad() - def predict_boundingbox( - self, - images, - texts, - HR_images=None, - num_beams=1, - max_new_tokens=300, - min_length=1, - top_p=0.9, - repetition_penalty=1, - length_penalty=1, - temperature=1, - do_sample=False, - stop_words_ids=[2], - output_hidden_states=True - ): - ''' - function for generate test use - ''' - - stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( - stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) - # images = torch.stack(images, dim=1) - img_embeds, atts_img, self_loss = self.encode_img(images.to(self.device), "LR_Encoding") - - if HR_images!=[]: - # print('-----------') - HR_embeds,HR_atts_img,HR_loss = self.encode_img(HR_images.to(self.device), "HR_Encoding") - img_embeds = self.cross_attention(img_embeds.reshape(1,-1,768).float(),HR_embeds.reshape(1,-1,768).float()) - # 修改1 在此处加上projector,映射到LLM维度4096 - img_embeds=self.llama_proj_mlps(img_embeds) # 1, 128, 4096 - img_embeds=img_embeds.reshape(1,-1,*img_embeds.shape[-1:]) - - if img_embeds.size(0)>8: - - img_embeds=img_embeds.reshape(1,-1,4096) - # print(img_embeds.shape) - - # img_embeds=self.conv3d_proj(img_embeds) - else: - img_embeds=img_embeds.reshape(1,-1,4096) - # print(img_embeds.shape) - - - image_lists = [[image_emb[None]] for image_emb in img_embeds] - - batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] - - batch_size = 
len(batch_embs) - max_len = max([emb.shape[1] for emb in batch_embs]) - emb_dim = batch_embs[0].shape[2] - dtype = batch_embs[0].dtype - device = batch_embs[0].device - - embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) - attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) - for i, emb in enumerate(batch_embs): - emb_len = emb.shape[1] - embs[i, -emb_len:] = emb[0] - attn_mask[i, -emb_len:] = 1 - - # with self.maybe_autocast(): - # outputs = self.llama_model.generate( - # inputs_embeds=embs, - # attention_mask=attn_mask, - # max_new_tokens=max_new_tokens, - # num_beams=num_beams, - # length_penalty=length_penalty, - # temperature=temperature, - # do_sample=do_sample, - # min_length=min_length, - # top_p=top_p, - # repetition_penalty=repetition_penalty, - # stopping_criteria=stopping_criteria, - # ) - + image = image.reshape(-1, 1, 32,*image.shape[-2:]) + # self.ln_vision=self.ln_vision_list[encoder_index].to(self.device) + self.visual_encoder=self.visual_encoder_list[encoder_index].to(self.device) with self.maybe_autocast(): - outputs = self.llama_model.generate( - inputs_embeds=embs, - attention_mask=attn_mask, - max_new_tokens=max_new_tokens, - num_beams=num_beams, - length_penalty=length_penalty, - temperature=temperature, - do_sample=do_sample, - min_length=min_length, - top_p=top_p, - repetition_penalty=repetition_penalty, - stopping_criteria=stopping_criteria, - return_dict_in_generate=True, - output_hidden_states=output_hidden_states, - pad_token_id=2, - eos_token_id=2 - ) - hidden_states=outputs['hidden_states'] - # [39,38,66177,1163,44,60158,40249,47,51,34,39] - - answers = [] - top_2_list=self.llama_model.get_top() - if top_2_list[0]!=None and top_2_list[1]!=None: - top2_answer=self.llama_tokenizer.decode(top_2_list, skip_special_tokens=True) - # print(top2_answer) - self.llama_model.delete_max_id() - - for output_token in outputs['sequences']: - if output_token[0] == 0: - output_token = output_token[1:] - - output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) - ''' - 打印了一下output_texts,为什么会这么长,不是应该生成完前面的答案之后,自己预测出结束符然后结束吗?换句话说后面应该都是结束符才对,而结束符会在decoder时skip掉 - - 'This patient was diagnosed with Embryonal tumours### Short diagnosis: Embryonal tumours### Diagnosis: MedulloblastomaThis patient was diagnosed - with Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### - Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal - tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: - Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### - Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal - tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: - Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### - Diagnosis: Embryonal tum' - - 后面几行代码看起来像处理结束符的,训练时有加上结束符吗 - ''' - output_texts = output_texts.split('')[0] # remove the stop sign - output_texts = output_texts.replace("", "") - 
output_texts = output_texts.split(r'[/INST]')[-1].strip() - answers.append(output_texts) - # print(output_texts) - - # print('answers',answers) - return answers - def get_input_embeddings(self): - return self.llama_model.get_input_embeddings() - - @torch.no_grad() - def generate_step( - self, - images, - texts, - HR_images=None, - num_beams=1, - echo=False, - max_new_tokens=500, - min_length=1, - top_p=0.9, - repetition_penalty=1, - length_penalty=1, - temperature=-1, # -1 0.01 0.05 0.1 0.2 0.4 0.6 0.8 1 - do_sample=False, - stop_words_ids=[2], - output_hidden_states=True - ): - ''' - function for generate test use - ''' - # breakpoint() - # self.llama_model.to(self.device) - stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( - stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) - # images = torch.stack(images, dim=1) - # breakpoint() - img_embeds, atts_img, self_loss = self.encode_img(images.to(self.device), "LR_Encoding") - - if HR_images!=[]: - # print('-----------') - HR_embeds,HR_atts_img,HR_loss = self.encode_img(HR_images.to(self.device), "HR_Encoding") - img_embeds = self.cross_attention(img_embeds.reshape(1,-1,768).float(),HR_embeds.reshape(1,-1,768).float()) - # 修改1 在此处加上projector,映射到LLM维度4096 - img_embeds=self.llama_proj_mlps(img_embeds) # 1, 128, 4096 - bsz=img_embeds.shape[0] - img_embeds=img_embeds.reshape(1,-1,*img_embeds.shape[-1:]) - # print(f'img_embeds{img_embeds.shape}') - if img_embeds.size(0)>8: - - img_embeds=img_embeds.reshape(1,-1,4096) - # print(img_embeds.shape) - - # img_embeds=self.conv3d_proj(img_embeds) - else: - img_embeds=img_embeds.reshape(1,-1,4096) - # print(img_embeds.shape) - - - # image_lists = [[image_emb[None]] for image_emb in img_embeds] - - # batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] - inputs_embeds,attention_mask,input_ids=self.prompt_wrap(img_embeds,atts_img,texts) - - # get bos token embedding - bos = torch.ones_like(attention_mask[:, :1]) * self.llama_tokenizer.bos_token_id - bos_embeds = self.embed_tokens(bos) # [1, 1, 4096] - bos_atts = attention_mask[:, :1] # [1, 1] 值为1 - - # add bos token at the begining - inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) # [1, 173, 4096] - # print(f'inputs_embeds{inputs_embeds.shape}') - attention_mask = torch.cat([bos_atts, attention_mask], dim=1) # [1, 173] 全1 - input_ids = torch.cat([bos, input_ids], dim=1) - - - # bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(batch_embs.device).input_ids) - # batch_embs = torch.cat([bos_embed, batch_embs], dim=1) - # batch_embs=[batch_embs] - - - - batch_size = len(inputs_embeds) - # min_len = min([emb.shape[1] for emb in inputs_embeds]) - # max_len = max([emb.shape[1] for emb in inputs_embeds]) - emb_dim = inputs_embeds.shape[2] - dtype = inputs_embeds.dtype - device = inputs_embeds.device - min_prompt_len = min(len(t) for t in inputs_embeds) # 84 - max_prompt_len = max(len(t) for t in inputs_embeds) # 84 - total_len = max_new_tokens + min_prompt_len - - ############# - # embeds = torch.full((bsz, total_len, inputs_embeds.shape[-1]), 0, dtype=torch.float, device="cuda") # 30 339 4096 - # for k, t in enumerate(inputs_embeds): - # embeds[k, : len(t)] = t - # for k, t in enumerate(input_ids): - # tokens[k, : len(t)] = t - ########### - - self.pad_token_id=self.llama_tokenizer.pad_token_id - pad_id = self.pad_token_id - tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.int, 
device=device) - embs = torch.zeros([batch_size, total_len, emb_dim], dtype=dtype, device=device) - for k, t in enumerate(inputs_embeds): - embs[k, : len(t)] = t - for k, t in enumerate(input_ids): - tokens[k, : len(t)] = t - - prev_pos=0 - eos_reached = torch.tensor([False] * bsz, device=self.device) - input_text_mask = tokens != pad_id - self.stop_token_id = int(self.llama_tokenizer.encode("###")[-1]) - stop_tokens = torch.tensor([self.stop_token_id], device=self.device) - - # attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) - - # for i, emb in enumerate(batch_embs): - # emb_len = emb.shape[1] - # embs[i, -emb_len:] = emb[0] - # attn_mask[i, -emb_len:] = 1 - # batch_embs=batch_embs[0] - - # eos_token=self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(device).input_ids[0] - # prev_pos=0 - past_key_values=None - # self.stop_token_id = int(self.llama_tokenizer.encode("###")[-1]) - # stop_tokens = torch.tensor([self.stop_token_id], device="cuda") - # eos_reached = torch.tensor([False] * bsz, device="cuda") - with_probs = False - calculate_prob=False - # breakpoint() - - for cur_pos in range(min_prompt_len, total_len): - with self.maybe_autocast(): - outputs = self.llama_model( - # attention_mask=attention_mask[:, prev_pos:cur_pos], - # position_ids=position_ids, - past_key_values=past_key_values, - inputs_embeds=embs[:, prev_pos:cur_pos], - use_cache=True, - output_hidden_states=output_hidden_states, - return_dict=True, - start_visual_idx = self.start_visual_token_idx, - end_visual_idx = self.end_visual_token_idx, - ) - - logits = outputs["logits"] # 1 176 128256 - past_key_values = outputs["past_key_values"] - - if temperature > 0: - probs = torch.softmax(logits[:, -1] / temperature, dim=-1) - next_token = sample_top_p(probs, top_p) + # loss,_,_,latent=self.visual_encoder(image.to(device)) + # return loss + if self.model_2d_or_3d=='3d': + latent_list=[] + for i in range(len(image)): + sub_image=image[i].reshape(-1,1,32,256,256) + # print(sub_image.shape) + latent,_=self.visual_encoder(sub_image.to(device)) + latent=latent[:,:-1,:] + # image_embeds = self.ln_vision(latent) + image_embeds=latent + image_embeds=self.projector(image_embeds) + # print(image_embeds.shape) + latent_list.append(image_embeds) + + latent_list=torch.stack(latent_list,dim=0) + latent_list=latent_list.reshape(-1,latent_list.size(-1)) + # print('latent_list',latent_list.shape) + inputs_llama=self.llama_proj(latent_list) else: - next_token = torch.argmax(logits[:, -1], dim=-1) - - next_token = next_token.reshape(-1) - # only replace token if prompt has already been generated - next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) - import copy - if calculate_prob: - probs = torch.softmax(logits[:, -1], dim=-1) - index_list=[15,16,17,18,19,20,21,22,23,24,605,806,717,1032] - new_logits=copy.deepcopy(logits[0][0]) - tumor_list=[new_logits[k] for k in index_list] - tumor_list = torch.stack(tumor_list) - probabilities = torch.softmax(tumor_list, dim=0) - token_dict={probabilities[k]:index_list[k] for k in range(len(index_list))} - top_2=sorted(token_dict.keys(), reverse=True)[:2] - # print(top_2[0].item(),top_2[1].item()) - top1=token_dict[top_2[0]] - top2=token_dict[top_2[1]] - print(f'',top_2[0].item(),f'',top_2[1].item()) - with_probs=False - calculate_prob=False - if next_token==449: - with_probs=True - if with_probs: - if next_token==62: - calculate_prob=True - - tokens[:, cur_pos] = next_token - 
next_token_embed = self.llama_model.get_input_embeddings()(next_token) - embs[:, cur_pos] = next_token_embed - # breakpoint() - eos_reached |= (~input_text_mask[:, cur_pos]) & ( - torch.isin(next_token, stop_tokens.to(self.device)) - ) - prev_pos = cur_pos - if all(eos_reached): - break - - # out_tokens, out_logprobs = [], [] - - ############################## - # for i, toks in enumerate(tokens.tolist()): - # # cut to max gen len - # start = 0 if echo else len(input_ids[i]) - # toks = toks[start: len(input_ids[i]) + max_gen_len] - # probs = None - # # if logprobs: - # # probs = token_logprobs[i][start : len(input_ids[i]) + max_gen_len] - # # cut to after eos tok if any - # for stop_token in [self.stop_token_id]: - # try: - # eos_idx = toks.index(stop_token) - # toks = toks[:eos_idx] - # # probs = probs[:eos_idx] if logprobs else None - # except ValueError: - # pass - # out_tokens.append(toks) - # # out_logprobs.append(probs) - # # return (out_tokens, out_logprobs if logprobs else None) - # return out_tokens - # hidden_states=outputs['hidden_states'] - # # [39,38,66177,1163,44,60158,40249,47,51,34,39] - - # answers = [] - # # top_2_list=self.llama_model.get_top() - # # if top_2_list[0]!=None and top_2_list[1]!=None: - # # top2_answer=self.llama_tokenizer.decode(top_2_list, skip_special_tokens=True) - # # # print(top2_answer) - # # self.llama_model.delete_max_id() - - # for output_token in tokens: - # if output_token[0] == 0: - # output_token = output_token[1:] - - # output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) - # output_texts = output_texts.split('')[0] # remove the stop sign - # output_texts = output_texts.replace("", "") - # output_texts = output_texts.split(r'[/INST]')[-1].strip() - # answers.append(output_texts) - # # print(output_texts) - out_tokens, out_logprobs = [], [] - for i, toks in enumerate(tokens.tolist()): - # cut to max gen len - start = 0 if echo else len(input_ids[i]) - toks = toks[start: len(input_ids[i]) + max_new_tokens] - probs = None - # if logprobs: - # probs = token_logprobs[i][start : len(input_ids[i]) + max_gen_len] - # cut to after eos tok if any - for stop_token in [self.stop_token_id]: - try: - eos_idx = toks.index(stop_token) - toks = toks[:eos_idx] - # probs = probs[:eos_idx] if logprobs else None - except ValueError: - pass - out_tokens.append(toks) - - answers = self.llama_tokenizer.decode(out_tokens[0]) - - return answers - - - - def generate_segmentation( - self, - images, - texts, - num_beams=1, - max_new_tokens=300, - min_length=1, - top_p=0.9, - repetition_penalty=1, - length_penalty=1, - temperature=1, - do_sample=False, - stop_words_ids=[2], - output_hidden_states=True - ): - ''' - function for generate test use - ''' - - stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( - stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) - # images = torch.stack(images, dim=1) - img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device)) - - image_lists = [[image_emb[None]] for image_emb in img_embeds] - # print(len(image_lists),len(texts)) - # for i in image_lists: - # print(i[0].shape) - batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] - - batch_size = len(batch_embs) - max_len = max([emb.shape[1] for emb in batch_embs]) - emb_dim = batch_embs[0].shape[2] - dtype = batch_embs[0].dtype - device = batch_embs[0].device - - embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) - attn_mask = 
torch.zeros([batch_size, max_len], dtype=torch.int, device=device) - for i, emb in enumerate(batch_embs): - emb_len = emb.shape[1] - embs[i, -emb_len:] = emb[0] - attn_mask[i, -emb_len:] = 1 - - with self.maybe_autocast(): - outputs = self.llama_model.generate( - inputs_embeds=embs, - attention_mask=attn_mask, - max_new_tokens=max_new_tokens, - num_beams=num_beams, - length_penalty=length_penalty, - temperature=temperature, - do_sample=do_sample, - min_length=min_length, - top_p=top_p, - repetition_penalty=repetition_penalty, - stopping_criteria=stopping_criteria, - return_dict_in_generate=True, + # print('image vit shape',image.shape) + # ((32((a,b,c)->32(np.zoom) sequence images number)*4),3,224,224) + b,c,w,h=image.shape # b 128 c 3 w h 224 + + if b < 256: + #### 修改 通过钩子找到 biomedclip last hidden layer output #### + # latent=self.visual_encoder(image.to(device)) + latent = self.hook_for_vit_features(image.to(device)) + if resolution_type == "LR_Encoding": # 对于LR图像,我们取196 token的平均,表示global的features + latent = latent.mean(dim=1) # 224, 196, 768 -> 224, 768 + elif resolution_type == "HR_Encoding": # 对于HR图像,我们保留196 token,表示detail的features + latent = latent + + + # image_embeds = self.ln_vision(latent) + image_embeds=latent # [128, 768] + else: + latent_list=[] + sub_batch=int(b/2) + for i in range(0,b,sub_batch): + sub_image=image[i:i+sub_batch] + # latent=self.visual_encoder(sub_image.to(device)) + #### 修改 通过钩子找到 biomedclip last hidden layer output #### + latent = self.hook_for_vit_features(sub_image.to(device)) + if resolution_type == "LR_Encoding": # 对于LR图像,我们取196 token的平均,表示global的features + latent = latent.mean(dim=1) # 224, 196, 768 -> 224, 768 + elif resolution_type == "HR_Encoding": # 对于HR图像,我们保留196 token,表示detail的features + latent = latent + # latent = self.ln_vision(latent) + latent_list.append(latent) + latent_list = torch.stack(latent_list, dim=0) + image_embeds = latent_list.reshape(-1, latent_list.size(-1)) - output_hidden_states=output_hidden_states - ) - hidden_states=outputs['hidden_states'] - # print(hidden_states[0].shape) - - # print(outputs.keys()) - location=(outputs['sequences'] == self.bounding_box_label).nonzero().flatten() - # print(self.bounding_box_label) - # print(len(outputs)) - # print('location',location) - # print(len(hidden_states)) - # print(hidden_states.shape) - predict_boundingbox=[] - if location.numel() != 0: - location=location[-1] - # print(outputs.shape) - # print(location) - # print(hidden_states[0].shape) - location_state=hidden_states[location] - # print(location_state.view(1,-1).shape) - predict_boundingbox=self.seg_layer(location_state.view(1,-1).float()) - # # print(1111111111111111111) - # print(predict_boundingbox.cpu()) - answers = [] - for output_token in outputs['sequences']: - if output_token[0] == 0: - output_token = output_token[1:] - - output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) - output_texts = output_texts.split('')[0] # remove the stop sign - output_texts = output_texts.replace("", "") - output_texts = output_texts.split(r'[/INST]')[-1].strip() - answers.append(output_texts) - # print(output_texts) - if predict_boundingbox!=[]: - return answers,predict_boundingbox - else: - return answers,[] - # print('answers',answers) - # return answers,predict_boundingbox - - - - - def predict_class_(self,images,reduction='mean'): - # img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device)) - # print(type(img_embeds)) - # print(img_embeds.shape) - cond_embeds, cond_atts, regress_embeds, 
regress_atts, part_targets,input_embs,class_target,self_loss = \ - self.preparing_embedding(images) - - inputs_embeds, attention_mask, input_lens = \ - self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts) - - # get bos token embedding - bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id - bos_embeds = self.embed_tokens(bos) - bos_atts = cond_atts[:, :1] - - # add bos token at the begining - inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) - attention_mask = torch.cat([bos_atts, attention_mask], dim=1) - - # ensemble the final targets - targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], - dtype=torch.long).to(self.device).fill_(-100) - - - with self.maybe_autocast(): - outputs = self.llama_model( - inputs_embeds=inputs_embeds, - attention_mask=attention_mask, + if self.has_qformer==True: + image_embeds=self.v_q_project(image_embeds) + if self.model_2d_or_3d=='2d': + image_embeds=image_embeds.view(image_embeds.size(0), 1, -1) # [128, 1, 768] + # else: + # print(image_embeds.shape) + + loss=None + if self.has_qformer and self.model_2d_or_3d=='2d': + image_atts = torch.ones(image_embeds.size()[:-1], dtype=torch.long).to(device) + query_tokens = self.query_tokens.expand(image_embeds.shape[0], -1, -1) + query_output = self.Qformer.bert( + query_embeds=query_tokens, + encoder_hidden_states=image_embeds, + encoder_attention_mask=image_atts, return_dict=True, - labels=None, - reduction=reduction ) - hidden_states=outputs.hidden_states - class_predict=self.class_layer(hidden_states[:,:32,:].view(1,-1).float()) - - - # class_predict=self.class_layer(img_embeds.view(1, -1).float()) - # if class_target!=None: - # classification_loss_fn=torch.nn.CrossEntropyLoss() - # class_loss=classification_loss_fn(class_predict,class_target) - return class_predict - - def predict_boundingbox_(self,samples): - if 'several_modalities_image' in samples: - images=samples['several_modalities_image']['images'] - modalities=samples['several_modalities_image']['modalities'] - # print(modalities) - device=images[0].device - temp=[] - for index,image in enumerate(images): - # print(image.shape) - img_embeds, img_atts,self_loss = self.encode_img(image.to(self.device),modalities[index]) - temp.append(img_embeds) - img_embeds=torch.cat(temp, dim=1) - # img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device)) - x=img_embeds.float() - x=self.decoder_linear(x) - for blk in self.decoder_blocks: - x = blk(x) - x = self.decoder_norm(x) - - # predictor projection - x = self.decoder_pred(x.view(1,-1)) - # print(x.shape) - return x - - @torch.no_grad() - def multi_select(self, images, texts, answers, num_cand=None): - all_losses = [] - for answer in answers: - choice_samples = { - 'image': images, - 'instruction_input': texts, - 'answer': answer - } - loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1) - all_losses.append(loss) - torch.cuda.empty_cache() - all_losses = torch.cat(all_losses, dim=-1) - if num_cand is not None: - for i in range(all_losses.shape[0]): - all_losses[i, num_cand[i]:] = 9999 - output_class_ranks = torch.argsort(all_losses, dim=-1) - return output_class_ranks.tolist() - -def sample_top_p(probs, p): - """ - Perform top-p (nucleus) sampling on a probability distribution. - - Args: - probs (torch.Tensor): Probability distribution tensor. - p (float): Probability threshold for top-p sampling. 
+ # print(query_output.last_hidden_state.shape) + + inputs_llama = self.llama_proj(query_output.last_hidden_state) + elif self.model_2d_or_3d=='2d': # 映射到4096 + # 修改1 这里去掉projector,而是放到后面crossatt后面 + # inputs_llama=self.llama_proj(image_embeds) # [128, 1, 4096] + inputs_llama = image_embeds + # else: + # # print(image_embeds.shape) + # # image_embeds = image_embeds[:, 1:, :] + # # bs, pn, hs = image_embeds.shape + # # image_embeds = image_embeds.view(bs, int(pn / 4), int(hs * 4)) + # image_embeds=self.projector(image_embeds) + # inputs_llama = self.llama_proj(image_embeds) + # loss=None + atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(image.device) # [128,1] + return inputs_llama, atts_llama,loss + + @classmethod + def from_config(cls, cfg): + vit_model = cfg.get("vit_model", "eva_clip_g") + q_former_model = cfg.get("q_former_model", "https://storage.googleapis.com/sfr-vision-language-research/LAVIS/models/BLIP2/blip2_pretrained_flant5xxl.pth") + img_size = cfg.get("image_size") + num_query_token = cfg.get("num_query_token") + llama_model = cfg.get("llama_model") + + drop_path_rate = cfg.get("drop_path_rate", 0) + use_grad_checkpoint = cfg.get("use_grad_checkpoint", False) + vit_precision = cfg.get("vit_precision", "fp16") + freeze_vit = cfg.get("freeze_vit", True) + has_qformer = cfg.get("has_qformer", True) + freeze_qformer = cfg.get("freeze_qformer", True) + low_resource = cfg.get("low_resource", False) + device_8bit = cfg.get("device_8bit", 0) + + prompt_path = cfg.get("prompt_path", "") + prompt_template = cfg.get("prompt_template", "") + max_txt_len = cfg.get("max_txt_len", 32) + end_sym = cfg.get("end_sym", '\n') + model_2d_or_3d=cfg.get("model_2d_or_3d", '2d') + print('model_2d_or_3d',model_2d_or_3d) + model = cls( + vit_model=vit_model, + q_former_model=q_former_model, + img_size=img_size, + drop_path_rate=drop_path_rate, + use_grad_checkpoint=use_grad_checkpoint, + vit_precision=vit_precision, + freeze_vit=freeze_vit, + has_qformer=has_qformer, + freeze_qformer=freeze_qformer, + num_query_token=num_query_token, + llama_model=llama_model, + prompt_path=prompt_path, + prompt_template=prompt_template, + max_txt_len=max_txt_len, + end_sym=end_sym, + low_resource=low_resource, + device_8bit=device_8bit, + model_2d_or_3d=model_2d_or_3d + ) + + ckpt_path = cfg.get("ckpt", "") # load weights of MiniGPT-4 + print(ckpt_path) + vit_stage1=False + if vit_stage1==True: + file_path="/home/ynwang/MAE-med_copy/open_clip/src/output_Yanzhao/2024_11_11-11_04_12-model_ViT-B-16-lr_5e-05-b_32-j_64-p_amp/checkpoints/epoch_93.pt" + print("load vit encoder") + ckpt = torch.load(file_path, map_location="cpu") + # 将model_a的权重拷贝到model_b + state_dict_a = ckpt['state_dict'] + state_dict_b = model.state_dict() + keys_list = [i for i in list(state_dict_b.keys()) if "visual_encoder" in i] + print(keys_list) + # breakpoint() + print(f'len of keys_list:{len(keys_list)}') + start_idx=0 + for para_idx,(name_a, param_a) in enumerate(state_dict_a.items()): + if para_idx>=151: # 读取完keys_list全部的150个key,最后两个是norm.weight and norm.bias + break + if para_idx==0: + continue + # if name_a in ["mask_token","decoder_pos_embed"]: + # continue + else: + cur_key=keys_list[start_idx] + # print(f'b:{cur_key},a:{name_a}') + # if not cur_key.split(".")[-1] == name_a.split(".")[-1]: # double check一下是否有layer名称不匹配的情况,发现没有 + # print(f"Warning! Layers do not match! 
cur_key: {cur_key}, biomedclip layer: {name_a}") + start_idx+=1 + if state_dict_b[cur_key].shape!=param_a.shape: + print("shape unmatch:{}".format(cur_key)) + else: + print("load parameter:{}".format(cur_key)) + # state_dict_b[cur_key]=param_a + state_dict_b[cur_key].copy_(param_a) - Returns: - torch.Tensor: Sampled token indices. + + # for name_b, param_b in state_dict_b.items(): + # if name_b in state_dict_a: + # param_a = state_dict_a[name_b] + # # 检查形状是否匹配,以避免错误 + # if param_a.shape == param_b.shape: + # print(f"Transferring weights for layer: {name_b}") + # state_dict_b[name_b].copy_(param_a) + # else: + # print(f"Shape mismatch at layer: {name_b}, cannot transfer weights.") + # else: + # # print(f"Layer {name_b} not found in model_a.") + # pass + msg = model.load_state_dict(state_dict_b, strict=False) + + if ckpt_path: + print("Load MiniGPT-4 Checkpoint: {}".format(ckpt_path)) + ckpt = torch.load(ckpt_path, map_location="cpu") + # 将model_a的权重拷贝到model_b + state_dict_a = ckpt['model'] + state_dict_b = model.state_dict() + for name_a,param_a in state_dict_a.items(): + print(name_a) + for name_b, param_b in state_dict_b.items(): + # print(name_b) + # llama_model.base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight + # llama_model.base_model.model.model.layers.0.self_attn.q_proj.lora_A.default.weight + if 'lora' in name_b: + # breakpoint() + new_name_b=name_b.replace('default.','') - Note: - Top-p sampling selects the smallest set of tokens whose cumulative probability mass - exceeds the threshold p. The distribution is renormalized based on the selected tokens. - """ - probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) - probs_sum = torch.cumsum(probs_sort, dim=-1) - mask = probs_sum - probs_sort > p - probs_sort[mask] = 0.0 - probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) - next_token = torch.multinomial(probs_sort, num_samples=1) - next_token = torch.gather(probs_idx, -1, next_token) - return next_token \ No newline at end of file + else: + new_name_b=name_b + if new_name_b in state_dict_a: + param_a = state_dict_a[new_name_b] + # 检查形状是否匹配,以避免错误 + if param_a.shape == param_b.shape: + print(name_b,param_b.shape,param_a.shape) + + # print(f"Transferring weights for layer: {name_b}") + state_dict_b[name_b].copy_(param_a) + else: + print(f"Shape mismatch at layer: {name_b}, cannot transfer weights.") + else: + # print(f"Layer {name_b} not found in model_a.") + pass + msg = model.load_state_dict(state_dict_b, strict=False) + # msg = model.load_state_dict(ckpt['model'], strict=False) + + return model + +''' +Transferring weights for layer: llama_model.base_model.model.model.layers.0.self_attn.q_proj.lora_A.weight +Transferring weights for layer: llama_model.base_model.model.model.layers.0.self_attn.q_proj.lora_B.weight +Transferring weights for layer: llama_model.base_model.model.model.layers.0.self_attn.v_proj.lora_A.weight +Transferring weights for layer: llama_model.base_model.model.model.layers.0.self_attn.v_proj.lora_B.weight +Transferring weights for layer: llama_model.base_model.model.model.layers.1.self_attn.q_proj.lora_A.weight +Transferring weights for layer: llama_model.base_model.model.model.layers.1.self_attn.q_proj.lora_B.weight +Transferring weights for layer: llama_model.base_model.model.model.layers.1.self_attn.v_proj.lora_A.weight +Transferring weights for layer: llama_model.base_model.model.model.layers.1.self_attn.v_proj.lora_B.weight +Transferring weights for layer: 
llama_model.base_model.model.model.layers.2.self_attn.q_proj.lora_A.weight
+... (the same self_attn.q_proj / self_attn.v_proj lora_A and lora_B entries repeat for every decoder layer up to layers.31) ...
+Shape mismatch at layer: cross_attention.Wq.weight, cannot transfer weights.
+Shape mismatch at layer: cross_attention.Wq.bias, cannot transfer weights.
+Shape mismatch at layer: cross_attention.Wk.weight, cannot transfer weights.
+Shape mismatch at layer: cross_attention.Wk.bias, cannot transfer weights.
+Shape mismatch at layer: cross_attention.Wv.weight, cannot transfer weights.
+Shape mismatch at layer: cross_attention.Wv.bias, cannot transfer weights.
+Shape mismatch at layer: cross_attention.output_linear.weight, cannot transfer weights.
+Shape mismatch at layer: cross_attention.output_linear.bias, cannot transfer weights.
+Transferring weights for layer: ln_vision.weight
+Transferring weights for layer: ln_vision.bias
+Transferring weights for layer: v_q_project.weight
+Transferring weights for layer: v_q_project.bias
+'''
\ No newline at end of file
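
# The from_config() loader above copies ckpt['model'] into the live model, stripping the
# "default." segment that PEFT inserts into LoRA parameter names and skipping tensors whose
# shapes do not match (the condensed log above shows the result). A minimal sketch of that
# remapping under those assumptions -- not the repository's exact loader:
import torch

def load_checkpoint_with_lora_rename(model: torch.nn.Module, ckpt_state: dict) -> None:
    """Copy matching tensors from ckpt_state into model, renaming LoRA keys on the way."""
    own_state = model.state_dict()
    for name, param in own_state.items():
        # e.g. "...q_proj.lora_A.default.weight" in the live model was saved in the
        # checkpoint as "...q_proj.lora_A.weight".
        source = name.replace("default.", "") if "lora" in name else name
        if source not in ckpt_state:
            continue
        saved = ckpt_state[source]
        if saved.shape == param.shape:
            param.copy_(saved)
        else:
            print(f"Shape mismatch at layer: {name}, cannot transfer weights.")
    model.load_state_dict(own_state, strict=False)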
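
# The 2D branch of encode_img() above grabs the encoder's last-hidden-layer tokens through a
# hook (self.hook_for_vit_features), then mean-pools them for "LR_Encoding" or keeps the full
# token grid for "HR_Encoding". A generic forward-hook sketch; the off-the-shelf torchvision
# ViT below is an assumption standing in for the BiomedCLIP encoder used in this repo:
import torch
from torchvision.models import vit_b_16

vit = vit_b_16()                      # stand-in encoder, randomly initialised
captured = {}

def grab_tokens(module, inputs, output):
    captured["tokens"] = output       # [B, 197, 768] tokens after the final LayerNorm

handle = vit.encoder.ln.register_forward_hook(grab_tokens)
with torch.no_grad():
    vit(torch.randn(2, 3, 224, 224))
handle.remove()

patch_tokens = captured["tokens"][:, 1:, :]   # drop the class token -> [2, 196, 768]
lr_features = patch_tokens.mean(dim=1)        # "LR_Encoding": one global vector per slice
hr_features = patch_tokens                    # "HR_Encoding": keep the per-patch detail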
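
# The '3d' branch of encode_img() above reshapes each scan into (N, 1, 32, 256, 256)
# sub-volumes, encodes each one, drops the trailing token, projects the tokens, and flattens
# the result for llama_proj. A shape-only sketch: the dummy encoder and its token count are
# assumptions, and a plain Linear stands in for SpatialPoolingProjector.
import torch
import torch.nn as nn

class DummyVolumeEncoder(nn.Module):
    """Stand-in for the 3D encoder; returns a made-up token sequence per sub-volume."""
    def forward(self, x):                                   # x: [N, 1, 32, 256, 256]
        return torch.randn(x.shape[0], 2049, 768), None    # assumed: 2048 patch tokens + 1 extra

encoder = DummyVolumeEncoder()
projector = nn.Linear(768, 768)                  # stand-in for SpatialPoolingProjector
llama_proj = nn.Linear(768, 4096)

volume = torch.randn(2, 1, 32, 256, 256)         # two sub-volumes of one scan
latent, _ = encoder(volume)
latent = latent[:, :-1, :]                       # drop the trailing token, as encode_img does
latent = projector(latent)                       # [2, 2048, 768]
inputs_llama = llama_proj(latent.reshape(-1, latent.size(-1)))  # [2 * 2048, 4096]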
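
# Every generate*() variant above packs a batch of variable-length prompt embeddings into a
# single left-padded tensor plus an integer attention mask before calling
# llama_model.generate(). A standalone sketch of that packing step; the toy shapes are made up:
import torch

def pad_prompt_embeddings(batch_embs, pad_left=True):
    """batch_embs: list of [1, seq_len_i, emb_dim] tensors -> (embs, attn_mask)."""
    batch_size = len(batch_embs)
    max_len = max(e.shape[1] for e in batch_embs)
    emb_dim = batch_embs[0].shape[2]
    dtype, device = batch_embs[0].dtype, batch_embs[0].device

    embs = torch.zeros(batch_size, max_len, emb_dim, dtype=dtype, device=device)
    attn_mask = torch.zeros(batch_size, max_len, dtype=torch.int, device=device)
    for i, emb in enumerate(batch_embs):
        n = emb.shape[1]
        if pad_left:                 # left padding keeps the prompt flush against generation
            embs[i, -n:] = emb[0]
            attn_mask[i, -n:] = 1
        else:
            embs[i, :n] = emb[0]
            attn_mask[i, :n] = 1
    return embs, attn_mask

embs, mask = pad_prompt_embeddings([torch.randn(1, 5, 8), torch.randn(1, 3, 8)])
assert embs.shape == (2, 5, 8) and mask.sum() == 8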
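
# generate_step() (removed in this diff) samples the next token with temperature scaling and
# top-p filtering via sample_top_p. A self-contained nucleus-sampling sketch of that step;
# the function and variable names here are my own, not the repository's:
import torch

def nucleus_sample(logits: torch.Tensor, temperature: float = 0.8, top_p: float = 0.9) -> torch.Tensor:
    """Sample one token id per row from logits [batch, vocab] with top-p filtering."""
    probs = torch.softmax(logits / temperature, dim=-1)
    sorted_probs, sorted_idx = torch.sort(probs, dim=-1, descending=True)
    cumulative = torch.cumsum(sorted_probs, dim=-1)
    # Mask tokens once the cumulative mass before them already exceeds top_p.
    mask = cumulative - sorted_probs > top_p
    sorted_probs[mask] = 0.0
    sorted_probs.div_(sorted_probs.sum(dim=-1, keepdim=True))
    next_sorted = torch.multinomial(sorted_probs, num_samples=1)
    return torch.gather(sorted_idx, -1, next_sorted)       # [batch, 1] token ids

# Greedy fallback mirrors the argmax branch used when temperature <= 0.
logits = torch.randn(2, 32000)
temperature = 0.8
tok = nucleus_sample(logits, temperature) if temperature > 0 else logits.argmax(dim=-1, keepdim=True)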
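
# The decoding loops above cut each generated string at a stop marker and keep only the text
# after the last "[/INST]" tag. The literal stop marker is garbled in this diff (it appears as
# split('')), so the "###" default below is an assumption, consistent with the "###" stop token
# encoded in generate_step:
def clean_generation(text: str, stop: str = "###") -> str:
    text = text.split(stop)[0]            # cut at the first stop marker
    text = text.split("[/INST]")[-1]      # keep only the assistant's answer
    return text.strip()

print(clean_generation("[INST] describe the scan [/INST] Findings: no acute lesion.### extra"))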