# BrainVLM: minigpt4/models/minigpt_base.py
import logging
import random
import torch
from torch.cuda.amp import autocast as autocast
import torch.nn as nn
import open_clip
from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel
from transformers import StoppingCriteria, StoppingCriteriaList, AutoTokenizer, AutoModel
from minigpt4.models.base_model import LayerNorm
from minigpt4.conversation.conversation import StoppingCriteriaSub
from minigpt4.models.eva_vit import Block
from minigpt4.models.mae_vit import mae_vit_large_patch16
# from transformers import
import math
import torch.nn.functional as F
from einops import rearrange
from einops.layers.torch import Rearrange
from open_clip.factory import HF_HUB_PREFIX, _MODEL_CONFIGS
import pdb
class SpatialPoolingProjector(nn.Module):
def __init__(self, image_size=(32, 256, 256), patch_size=(8, 16, 16), in_dim=768, out_dim=768, layer_type='linear', layer_num=6, pooling_type='spatial', pooling_size=2):
# print(in_dim,out_dim)
super().__init__()
self.in_dim = in_dim
self.pooling_size = pooling_size
self.num_patches_pre = [img // pch for img, pch in zip(image_size, patch_size)]
# patches per axis before pooling, e.g. [8, 16, 16]
self.num_patches_post = [num // pooling_size for num in self.num_patches_pre]
# patches per axis after pooling, e.g. [4, 8, 8]
if layer_type == 'linear':
depth = int(layer_num)
modules = [nn.Linear(in_dim, out_dim)]
for _ in range(1, depth):
modules.append(nn.Linear(out_dim, out_dim))
self.projector = nn.Sequential(*modules)
elif layer_type == 'mlp':
depth = int(layer_num)
modules = [nn.Linear(in_dim, out_dim)]
for _ in range(1, depth):
modules.append(nn.GELU())
modules.append(nn.Linear(out_dim, out_dim))
self.projector = nn.Sequential(*modules)
else:
raise ValueError(f"Unsupported projector layer_type: {layer_type}")
self.pooling_type = pooling_type
def forward(self, x):
B = x.shape[0] # B*N*D
# print(x.shape)
if self.pooling_type == 'spatial':
to_3d = Rearrange("b (p1 p2 p3) d -> b d p1 p2 p3", b=B, d=self.in_dim, p1=self.num_patches_pre[0], p2=self.num_patches_pre[1], p3=self.num_patches_pre[2])
x = to_3d(x)
x = F.avg_pool3d(x, kernel_size=self.pooling_size, stride=self.pooling_size)
to_seq = Rearrange("b d p1 p2 p3 -> b (p1 p2 p3) d", b=B, d=self.in_dim, p1=self.num_patches_post[0], p2=self.num_patches_post[1], p3=self.num_patches_post[2])
x = to_seq(x)
elif self.pooling_type == 'sequence':
x = x.permute(0, 2, 1) #b d n
x = F.avg_pool1d(x, kernel_size=self.pooling_size**3, stride=self.pooling_size**3)
x = x.permute(0, 2, 1) #b n d
x = rearrange(x, "b n d -> (b n) d")
# print(x.shape)
x = self.projector(x)
x = rearrange(x, "(b n) d -> b n d", b=B)
return x
@property
def proj_out_num(self):
num = 1
for n in self.num_patches_post:
num *= n
return num
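# --- Illustrative usage sketch (hedged; not part of the training pipeline) ---
# With the default image_size=(32, 256, 256) and patch_size=(8, 16, 16), the
# projector pools 4*16*16 = 1024 patch tokens down to 2*8*8 = 128 tokens.
def _demo_spatial_pooling_projector():
    proj = SpatialPoolingProjector(in_dim=768, out_dim=768,
                                   layer_type='mlp', layer_num=2,
                                   pooling_type='spatial', pooling_size=2)
    x = torch.randn(2, 4 * 16 * 16, 768)   # (batch, num_patches_pre, in_dim)
    y = proj(x)                             # (batch, num_patches_post, out_dim)
    assert y.shape == (2, proj.proj_out_num, 768)
    return y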
def calculate_iou(box1,box2):
y1max, x1max, y1min, x1min = box1[0], box1[1], box1[2], box1[3]
y2max, x2max, y2min, x2min = box2[0], box2[1], box2[2], box2[3]
# compute the area of each box
s1 = (y1max - y1min + 1.) * (x1max - x1min + 1.)
s2 = (y2max - y2min + 1.) * (x2max - x2min + 1.)
# compute the coordinates of the intersection
xmin = max(x1min,x2min)
ymin = max(y1min,y2min)
xmax = min(x1max,x2max)
ymax = min(y1max,y2max)
inter_h = max(ymax - ymin + 1, 0)
inter_w = max(xmax - xmin + 1, 0)
intersection = inter_h * inter_w
union = s1 + s2 - intersection
# compute the IoU
iou = intersection / union
return iou
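# Toy example (hedged; corner format [ymax, xmax, ymin, xmin] with the +1
# pixel-inclusive widths used above): a 5x5 box that fully contains a 3x3 box
# gives IoU = 9 / 25 = 0.36.
def _demo_calculate_iou():
    box_a = [4, 4, 0, 0]
    box_b = [2, 2, 0, 0]
    return calculate_iou(box_a, box_b)  # 0.36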
def calculate_giou_loss(ground_truth, predicted):
"""
Compute the GIoU loss between ground-truth and predicted boxes.
"""
iou = calculate_iou(ground_truth, predicted)
# compute the enclosing-box terms used by the GIoU penalty
enclose_y1 = torch.min(ground_truth[:, 0], predicted[:, 0])
enclose_x1 = torch.min(ground_truth[:, 1], predicted[:, 1])
enclose_y2 = torch.max(ground_truth[:, 2], predicted[:, 2])
enclose_x2 = torch.max(ground_truth[:, 3], predicted[:, 3])
enclose_area = torch.clamp(enclose_x2 - enclose_x1, min=0) * torch.clamp(enclose_y2 - enclose_y1, min=0)
# print(iou)
# NOTE: standard GIoU subtracts (enclose_area - union) / enclose_area; here the IoU value is reused in the penalty term instead.
giou = iou - (enclose_area - calculate_iou(ground_truth, predicted)) / (enclose_area + 1e-10)
giou_loss = 1.0 - giou
return giou_loss.mean()
class mulit_modality_Attention(nn.Module):
def __init__(self, in_channels=16, out_channels=8,emb_dim=4096, output_dim=4096,att_dropout=0.0, aropout=0.0):
super(mulit_modality_Attention, self).__init__()
self.emb_dim = emb_dim
self.scale = emb_dim ** -0.5
# self.proj_in = nn.Linear(in_channels, emb_dim, kernel_size=1, stride=1, padding=0)
self.Wq = nn.Linear(emb_dim, emb_dim)
self.Wk = nn.Linear(emb_dim, emb_dim)
self.Wv = nn.Linear(emb_dim, emb_dim)
self.softmax = nn.Softmax(dim=-1)
self.output_linear=nn.Linear(emb_dim,emb_dim)
self.output_proj = nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1))
def forward(self, x_list, context=None, pad_mask=None):
'''
:param x_list: [1-4, 32, 4096]
:param context: [batch_size, seq_len, emb_dim]
:param pad_mask: [batch_size, seq_len, seq_len]
:return:
'''
b, h, w = x_list.shape
# query_list=[]
# key_list=[]
# value_list=[]
# for i in range(b):
# query_list.append(self.Wq(x_list[i,:,:]))
# key_list.append(self.Wk(x_list[i,:,:]))
# value_list.append(self.Wv(x_list[i,:,:]))
# print(query_list[0].shape)
query_list=self.Wq(x_list)
key_list=self.Wk(x_list)
value_list=self.Wv(x_list)
# attention_score=torch.einsum('bid,bjd->bij',query_list,key_list)
q_c,q_w,q_h=query_list.shape
# query_list=query_list.reshape(q_c,-1,q_w,q_h)
# key_list=key_list.reshape(-1,q_c,q_w,q_h)
# print("shape1",query_list.shape)
# print('shape2',key_list.shape)
# attention_score=torch.einsum('bij')
attention_score=torch.matmul(query_list,key_list.transpose(-1,-2))
attention_score=self.softmax(attention_score*self.scale)
# print(attention_score.shape)
attention_score=attention_score.reshape(q_c,-1,q_w,q_w)
value_list=value_list.reshape(-1,q_c,q_w,q_h)
output=torch.matmul(attention_score,value_list)
output=self.output_linear(output)
o_b,o_c,o_w,o_d=output.shape
output=output.reshape(1,o_b*o_c,o_w,o_d)
# print(output.shape)
if o_b*o_c==16:
output=self.output_proj(output)
output=output.reshape(1,-1,o_w,o_d)
# # [batch_size, h*w, seq_len]
# att_weights = torch.einsum('bid,bjd -> bij', Q, K)
# att_weights = att_weights * self.scale
# if pad_mask is not None:
# # [batch_size, h*w, seq_len]
# att_weights = att_weights.masked_fill(pad_mask, -1e9)
# att_weights = F.softmax(att_weights, dim=-1)
# out = torch.einsum('bij, bjd -> bid', att_weights, V) # [batch_size, h*w, emb_dim]
# out = rearrange(out, 'b (h w) c -> b c h w', h=h, w=w) # [batch_size, c, h, w]
# out = self.proj_out(out) # [batch_size, c, h, w]
# print(out.shape)
return output
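# Hedged shape sketch for the multi-modality attention module above (it does
# not appear to be used in the current forward path): with four 32-token
# modality embeddings of width 4096, the broadcasted attention and the
# AvgPool2d branch reduce the 4x4 pairwise outputs to shape (1, 8, 32, 4096).
def _demo_mulit_modality_attention():
    attn = mulit_modality_Attention(emb_dim=4096)
    x = torch.randn(4, 32, 4096)        # (num_modalities, tokens, emb_dim)
    out = attn(x)
    assert out.shape == (1, 8, 32, 4096)
    return out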
# previous cross-attention version that produced NaNs
# class resolution_attention(nn.Module):
# def __init__(self, in_channels=16, out_channels=8,emb_dim=768, output_dim=768,att_dropout=0.0, aropout=0.0):
# super(resolution_attention, self).__init__()
# self.emb_dim = emb_dim
# self.scale = emb_dim ** -0.5
# # self.proj_in = nn.Linear(in_channels, emb_dim, kernel_size=1, stride=1, padding=0)
# # self.batch_normal_layer_hr=nn.BatchNorm1d(1)
# # self.batch_normal_layer_lr=nn.BatchNorm1d(1)
# self.Wq = nn.Linear(emb_dim, emb_dim)
# self.Wk = nn.Linear(emb_dim, emb_dim)
# self.Wv = nn.Linear(emb_dim, emb_dim)
# self.bn_layer1=nn.BatchNorm1d(emb_dim)
# self.bn_layer2=nn.BatchNorm1d(emb_dim)
# self.softmax = nn.Softmax(dim=-1)
# self.output_linear=nn.Linear(emb_dim,emb_dim)
# self.relu_fn=nn.ReLU()
# # self.output_proj = nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1))
# self.drop_out=nn.Dropout(0.)
# def forward(self, HR_image, LR_image, context=None, pad_mask=None): # NOTE: the argument names are swapped here; HR_image is actually the LR input and LR_image the HR input
# '''
# :param x: [1-4, 32, 4096]
# :param context: [batch_szie, seq_len, emb_dim]
# :param pad_mask: [batch_size, seq_len, seq_len]
# :return:
# '''
# #batch size is 1
# HR_image=self.bn_layer1(HR_image.reshape(-1,HR_image.size(-1))).reshape(1,-1,HR_image.size(-1)) # 1, 128, 4096
# LR_image=self.bn_layer2(LR_image.reshape(-1,HR_image.size(-1))).reshape(1,-1,HR_image.size(-1)) # 1, 512, 4096
# query_list=self.Wq(HR_image) # 1, 128, 4096
# key_list=self.Wk(LR_image) # 1, 512, 4096
# value_list=self.Wv(LR_image) # 1, 512, 4096
# attention_score=torch.matmul(query_list,key_list.transpose(-1,-2)) # 1, 128, 512
# attention_score=self.softmax(attention_score*self.scale) # 1, 128, 512
# output=torch.matmul(attention_score,value_list) # 1, 128, 4096
# output=self.output_linear(output) + HR_image # 1, 128, 4096
# return output
def attention(query, key, value, mask=None, dropout=None):
d_k = query.size(-1) # 768
scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) # 1,5,32,1,784 vs 1,5,32,1,4 vs 1, 160, 512
if mask is not None:
scores = scores.masked_fill(mask == 0, -1e9)
p_attn = F.softmax(scores, dim=-1)
if dropout is not None:
p_attn = dropout(p_attn)
return torch.matmul(p_attn, value), p_attn
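# Hedged shape example for the scaled dot-product helper above: it broadcasts
# over leading dimensions, so a query of shape (1, 5, 32, 1, 768) attending to
# keys/values of shape (1, 5, 32, 4, 768) yields an output of (1, 5, 32, 1, 768)
# and attention weights of (1, 5, 32, 1, 4). Shapes are illustrative only.
def _demo_scaled_dot_product_attention():
    q = torch.randn(1, 5, 32, 1, 768)
    k = torch.randn(1, 5, 32, 4, 768)
    v = torch.randn(1, 5, 32, 4, 768)
    out, attn_weights = attention(q, k, v)
    assert out.shape == (1, 5, 32, 1, 768)
    assert attn_weights.shape == (1, 5, 32, 1, 4)
    return out, attn_weights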
class resolution_attention(nn.Module):
def __init__(self, in_channels=16, out_channels=8, emb_dim=768, output_dim=768, dropout=0.1, aropout=0.0):
super(resolution_attention, self).__init__()
self.emb_dim = emb_dim
self.Wq = nn.Linear(emb_dim, emb_dim)
self.Wk = nn.Linear(emb_dim, emb_dim)
self.Wv = nn.Linear(emb_dim, emb_dim)
self.attn = None
self.output_linear = nn.Linear(emb_dim,emb_dim)
self.dropout = nn.Dropout(p=dropout) # dropout on the attention weights
self.dropout_2 = nn.Dropout(p=dropout) # dropout before the residual connection
self.norm = nn.LayerNorm(emb_dim) # LayerNorm after the residual connection
def forward(self, LR_image, HR_image, context=None, mask=None):
'''
:param LR_image: [1-4, 32, 768] low-resolution slice tokens (queries)
:param HR_image: high-resolution patch tokens (keys/values)
:param context: [batch_size, seq_len, emb_dim]
:param mask: [batch_size, seq_len, seq_len]
:return:
'''
version = 2 # or 1 2
if version == 0:
# previous version: the LR tokens of all images attend to the HR tokens of all images; this can go wrong when different samples contain different modalities
LR_image = LR_image.reshape(-1,LR_image.size(-1)).reshape(1,-1,LR_image.size(-1)) # 1, 128, 768
HR_image = HR_image.reshape(-1,HR_image.size(-1)).reshape(1,-1,HR_image.size(-1)) # 1, 512, 768
elif version == 1:
# revised version: each image's LR tokens attend only to their own HR tokens
LR_image = LR_image.view(1, -1, 32, 1, LR_image.size(-1))
HR_image = HR_image.view(1, -1, 32, 4, HR_image.size(-1)) # the HR side has 4x as many images, i.e. 32*4
elif version == 2:
# breakpoint()
# print(f"LR_image.shape:{LR_image.shape},HR_image.shape:{HR_image.shape}")
# revised for ViT features: HR is [196*224*4, 1, 768], LR is [224, 1, 768]
LR_image = LR_image.view(1, -1, 32, 1, LR_image.size(-1))
HR_image = HR_image.view(1, -1, 32, 4*196, HR_image.size(-1)) # the HR side has 4x as many images, each with 196 visual tokens, i.e. 32*4*196
query_list=self.Wq(LR_image) # version3: 1,7,32,1,768 version2: 1,5,32,1,768 version1: 1, 160, 768
key_list=self.Wk(HR_image) # version3: 1,7,32,784,768 version2: 1,5,32,4,768 version1: 1, 640, 768
value_list=self.Wv(HR_image) # version3: 1,7,32,784,768 version2: 1,5,32,4,768 version1: 1, 640, 768
x, self.attn = attention(query_list, key_list, value_list, mask=mask, dropout=self.dropout) # 1,5,32,1,768
if version in [1, 2]:
# reshape back to (1, -1, 768)
x = x.view(1,-1,LR_image.size(-1))
query_list = query_list.view(1,-1,LR_image.size(-1))
x = self.output_linear(x) # [1, 128, 768]
# residual connection + LayerNorm, following the standard Transformer block
x = self.norm(query_list + self.dropout_2(x)) # [1, 128, 768]
return x
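# Hedged usage sketch for the LR->HR cross-attention block (the hard-coded
# "version == 2" path). Token counts mirror the comments in forward(): 4
# modalities x 32 slices give 128 LR tokens, and each slice contributes
# 4 sub-images x 196 ViT patch tokens on the HR side. These sizes are
# assumptions for illustration, not the exact training configuration.
def _demo_resolution_attention():
    xattn = resolution_attention(emb_dim=768)
    lr = torch.randn(1, 4 * 32, 768)                # 128 LR slice tokens
    hr = torch.randn(1, 4 * 32 * 4 * 196, 768)      # matching HR patch tokens
    fused = xattn(lr, hr)                           # (1, 128, 768)
    assert fused.shape == (1, 4 * 32, 768)
    return fused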
def bbox_giou_loss(pred_bboxes, target_bboxes):
"""
Compute the Generalized IoU (GIOU) loss between predicted bounding boxes and target bounding boxes.
Args:
- pred_bboxes: Predicted bounding boxes, tensor of shape (N, 4), where N is the number of bounding boxes.
Each bounding box is represented as [ymax, xmax, ymin, xmin].
- target_bboxes: Target bounding boxes, tensor of shape (N, 4), where N is the number of bounding boxes.
Each bounding box is represented as [ymax, xmax, ymin, xmin].
Returns:
- giou_loss: Computed GIOU loss.
"""
x_pred_max=torch.max(pred_bboxes[:,1],pred_bboxes[:,3])
x_pred_min=torch.min(pred_bboxes[:,1],pred_bboxes[:,3])
y_pred_min=torch.min(pred_bboxes[:,0],pred_bboxes[:,2])
y_pred_max=torch.max(pred_bboxes[:,0],pred_bboxes[:,2])
x_gt_max=torch.max(target_bboxes[:,1],target_bboxes[:,3])
x_gt_min=torch.min(target_bboxes[:,1],target_bboxes[:,3])
y_gt_min=torch.min(target_bboxes[:,0],target_bboxes[:,2])
y_gt_max=torch.max(target_bboxes[:,0],target_bboxes[:,2])
ymin = torch.max(y_pred_min, y_gt_min)
xmin = torch.max(x_pred_min, x_gt_min)
ymax = torch.min(y_pred_max, y_gt_max)
xmax = torch.min(x_pred_max, x_gt_max)
intersection = torch.clamp((xmax - xmin), min=0) * torch.clamp((ymax - ymin), min=0)
# Calculate union
pred_area = (x_pred_max - x_pred_min) * (y_pred_max - y_pred_min)
target_area = (x_gt_max - x_gt_min) * (y_gt_max - y_gt_min)
union = pred_area + target_area - intersection
iou=intersection / union
# Calculate loss (note: the enclosing-box penalty of full GIoU is not applied here, so this reduces to the plain IoU loss)
# giou_loss = 1 - giou
# print(iou)
giou_loss=1-iou
return giou_loss.mean()
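# Worked toy example (hedged): in the [ymax, xmax, ymin, xmin] convention a
# 4x4 prediction that fully contains a 2x2 target has intersection 4, union 16,
# so IoU = 0.25 and the returned loss is 0.75. As noted above, bbox_giou_loss
# currently omits the enclosing-box penalty and reduces to 1 - IoU.
def _demo_bbox_giou_loss():
    pred = torch.tensor([[4.0, 4.0, 0.0, 0.0]])    # ymax, xmax, ymin, xmin
    target = torch.tensor([[2.0, 2.0, 0.0, 0.0]])
    loss = bbox_giou_loss(pred, target)            # tensor(0.7500)
    return loss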
class MiniGPTBase(BaseModel):
"""
Base class for MiniGPT-4 and MiniGPT-v2
"""
def __init__(
self,
vit_model="eva_clip_g",
img_size=224,
drop_path_rate=0,
use_grad_checkpoint=False,
vit_precision="fp16",
freeze_vit=False,
llama_model="",
max_txt_len=32,
max_context_len=3800,
prompt_template="",
end_sym='\n',
low_resource=False, # use 8 bit and put vit in cpu
device_8bit=0, # the device of 8bit model should be set when loading and cannot be changed anymore.
lora_r=8, # lora_r = 0 means LoRA is not used; used here with lora_r=8, lora_alpha=32
lora_target_modules=["q_proj", "v_proj"],
lora_alpha=32, # or 64 128/256
lora_dropout=0.05,
modality_number=5,
model_2d_or_3d="3d",
self_training=False
):
super().__init__()
self.self_training=self_training
freeze_vit=not self_training
self.model_2d_or_3d=model_2d_or_3d
print('self.model_2d_or_3d',self.model_2d_or_3d)
print('self_training is ',self_training)
if self.self_training==False:
self.llama_model, self.llama_tokenizer = self.init_llm(
llama_model_path=llama_model,
low_resource=low_resource,
low_res_device=device_8bit,
lora_r=lora_r,
lora_target_modules=lora_target_modules,
lora_alpha=lora_alpha,
lora_dropout=lora_dropout,
)
self.visual_encoder_list=[]
self.ln_vision_list=[]
self.modality_number=modality_number
# speed things up: move the model onto the GPU first, otherwise it is too slow
self.llama_model = self.llama_model.cuda()
# 2024.09.07 Yanzhaoshi: add_special_tokens adds these special tokens to the tokenizer; each one should be treated as a single unit
self.llama_tokenizer.add_special_tokens({'additional_special_tokens':["<box>", "<Img>", "</Img>", "<t>"]})
# self.llama_tokenizer.add_special_tokens({'additional_special_tokens':[f'<class_{i}>' for i in range(12)]})
# If the tokenizer vocab size no longer matches the embedding matrix of self.llama_model,
# warn and then expand the embedding matrix
if len(self.llama_tokenizer) > self.llama_model.get_input_embeddings().weight.shape[0] and lora_r>0:
print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.")
self.llama_model.base_model.model.model.resize_token_embeddings(new_num_tokens = len(self.llama_tokenizer))
else:
self.llama_model.base_model.resize_token_embeddings(new_num_tokens = len(self.llama_tokenizer))
# self.llama_tokenizer.add_special_tokens({'additional_special_tokens':["<box>"]})
# breakpoint()
self.start_visual_token_idx = None # records the position of the first visual token
self.end_visual_token_idx = None # records the position of the last visual token
# self.llama_model.resize_token_embeddings(len(self.llama_tokenizer))
# box_token_id = self.llama_tokenizer.convert_tokens_to_ids('<box>')
# BOX_TOKEN='<box>'
# self.llama_tokenizer.add_tokens(["<Img>", "</Img>"], special_tokens=True)
self.bounding_box_label = self.llama_tokenizer.convert_tokens_to_ids("<box>")
# Langerhans cell histiocytosis
# print('meningioma',self.llama_tokenizer.convert_tokens_to_ids("meningioma"))
# print('self.bounding_box_label',self.bounding_box_label)
# self.location_label=self.llama_tokenizer.convert_ids_to_tokens([2287, 317, 351, 986, 284, 29892, 4423, 29901, 29871])
# self.class_layer=nn.Linear(32*4096,8).to(self.device)
# self.boundinb_box_relu=nn.ReLU(inplace=True)
self.cross_attention=resolution_attention()
self.L1_loss_fn=nn.SmoothL1Loss()
# if self.model_2d_or_3d=='2d':
# self.spatial_pooling_layer_2d=SpatialPoolingProjector(in_dim=4096,out_dim=4096)
# self.conv_layer=nn.Conv3d(32,16,kernel_size=3,stride=1,padding=1)
# self.conv3d_layer=nn.Sequential(
# nn.Conv3d(32,16,kernel_size=3,stride=1,padding=1),
# nn.Conv3d(16,8,kernel_size=3,stride=1,padding=1)
# )
# self.conv3d_proj=nn.Linear(4096,4096)
# self.bounding_box_layer=nn.Linear(4*32*4096,4).to(self.device)
# decoder_embed_dim=512
# decoder_num_heads=16
# self.decoder_linear=nn.Linear(4096,512).to(self.device)
# self.decoder_blocks = nn.ModuleList([
# Block(decoder_embed_dim, decoder_num_heads, 4, qkv_bias=True, qk_scale=None, norm_layer=nn.LayerNorm)
# for i in range(1)])
# self.decoder_norm = nn.LayerNorm(decoder_embed_dim)
# self.decoder_pred = nn.Linear(decoder_embed_dim*128, 4, bias=True)
for i in range(self.modality_number):
# print(freeze_vit)
# self.visual_encoder,self.ln_vision = self.init_vision_encoder(
# vit_model, img_size, drop_path_rate, use_grad_checkpoint, vit_precision, freeze_vit,self.self_training,i
# )
if self.model_2d_or_3d=='2d':
# import json
# with open("/grp01/saas_medml/yinong/ckpt/models--microsoft--BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/snapshots/f658b0b9eeca714b891d004b2d714c7ad3705170/open_clip_config.json", "r") as f:
# config = json.load(f)
# model_cfg = config["model_cfg"]
# preprocess_cfg = config["preprocess_cfg"]
# # the [CLS] feature is extracted
# model_name='biomedclip_local'
# if (not model_name.startswith(HF_HUB_PREFIX)
# and model_name not in _MODEL_CONFIGS
# and config is not None):
# _MODEL_CONFIGS[model_name] = model_cfg
# # tokenizer = get_tokenizer(model_name)
# model, _, _ = open_clip.create_model_and_transforms(model_name=model_name,pretrained='/grp01/saas_medml/yinong/ckpt/models--microsoft--BiomedCLIP-PubMedBERT_256-vit_base_patch16_224/snapshots/f658b0b9eeca714b891d004b2d714c7ad3705170/open_clip_pytorch_model.bin',**{f"image_{k}": v for k, v in preprocess_cfg.items()})
# the [CLS] feature is extracted
model, _, _ = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224')
'''
2024.09.02 Notes on the vision encoder setup
# To extract last-hidden-layer features instead of the [CLS] feature:
# load the config
vision_cfg = open_clip.CLIPVisionCfg
# set output_tokens=True; forward() then returns two values: the original [CLS] token feature and the features of all tokens
vision_cfg.output_tokens = True
# load this config into the model
model, _, _ = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', vision_cfg = vision_cfg)
[This does not seem to work] BiomedCLIP is wrapped very deeply and lacks a parameter to control which layer is output. Many people have hit the same problem:
https://github.com/mlfoundations/open_clip/issues/844
Suggestion: try plain CLIP directly; it can later be replaced with our own pretrained CLIP.
# extract 1024-dimensional features
from transformers import CLIPProcessor, CLIPVisionModel
model_path = "models/clip-vit-large-patch14"
processor = CLIPProcessor.from_pretrained(model_path)
vision_model = CLIPVisionModel.from_pretrained(model_path).to(device)
vision_model.to(device)
images = Image.open(path)
inputs = processor(images=images, return_tensors="pt", padding=True).to(device)
with torch.no_grad():
vision_outputs = vision_model(**inputs)
# last-hidden-layer features, (257, 1024), can be used as local features
last_hidden_state = vision_outputs.last_hidden_state
# [CLS] token feature, (1024,), can be used as a global feature
pooler_output = vision_outputs.pooler_output
'''
'''
Use a forward hook to extract BiomedCLIP's last-hidden-layer features
# register the hook
'''
self.visual_encoder = model.visual.trunk
# model.load_state_dict(torch.load('/mnt/7T/yinong/ckpt/vit/test_125.pth'))
self.visual_encoder.requires_grad_(False)
is_unfrozen = True # True False
if is_unfrozen:
# unfreeze the final norm layer of the ViT
self.visual_encoder.norm.weight.requires_grad_(True)
self.visual_encoder.norm.bias.requires_grad_(True)
# unfreeze the last n Transformer blocks of the ViT
# N = -1 # unfreeze the last N transformer blocks
# for block in self.visual_encoder.blocks[N:].parameters(): # unfreeze the last N transformer blocks of the vision encoder
# block.requires_grad = True
# iterate over the model parameters and print each requires_grad flag to check which parameters are being trained
# for name, param in self.visual_encoder.named_parameters():
# print(f"{name}: requires_grad={param.requires_grad}")
# self.ln_vision=nn.LayerNorm(self.visual_encoder.num_features)
# earliest single-layer MLP version
# self.llama_proj_mlp = nn.Linear(768, 4096)
# previous llama_proj_mlps version
# self.llama_proj_mlps = nn.Sequential(
# nn.Linear(768, 768),
# nn.Linear(768, 4096),
# )
# Yanzhaoshi 2024.09.08: projector following the LLaVA-Med design
self.llama_proj_mlps = nn.Sequential(
nn.Linear(768, 4096),
# activation: nn.ReLU() is slightly faster; LLaVA-Med uses nn.GELU(), which costs a bit more but usually works better than ReLU for NLP
nn.GELU(),
# nn.ReLU(),
nn.Linear(4096, 4096),
)
elif self.model_2d_or_3d=='3d':
print('-----------------------------3d vit encoder-------------')
model = AutoModel.from_pretrained(
"GoodBaiBai88/M3D-CLIP",
# "GoodBaiBai88/M3D-LaMed-Llama-2-7B",
trust_remote_code=True)
self.visual_encoder=model.vision_encoder
# self.visual_encoder.requires_grad_(False)
# self.ln_vision=nn.LayerNorm(self.visual_encoder.hidden_size)
load_vit=False
if load_vit==True:
state_dict_b = self.visual_encoder.state_dict()
keys_list = list(state_dict_b.keys())
file_path="/home/ynwang/MiniGPT-4/minigpt4/ckpt/vit/checkpoint_59.pth"
print("load vit encoder")
ckpt = torch.load(file_path, map_location="cpu")
# copy the weights of model_a into model_b
state_dict_a = ckpt['model']
for para_idx,(name_a, param_a) in enumerate(state_dict_a.items()):
if para_idx>=144:
break
else:
cur_key=keys_list[para_idx]
if state_dict_b[cur_key].shape!=param_a.shape:
print("shape unmatch:{}".format(cur_key))
else:
print("load parameter:{}".format(cur_key))
state_dict_b[cur_key]=param_a
msg = self.visual_encoder.load_state_dict(state_dict_b, strict=False)
self.visual_encoder_list.append(self.visual_encoder.to(self.device))
# self.ln_vision_list.append(self.ln_vision.to(self.device))
self.max_txt_len = max_txt_len
self.max_context_len = max_context_len
self.end_sym = end_sym
self.prompt_template = prompt_template
self.prompt_list = []
self.grad_list=[]
# self.freeze_model()
# self.visual_encoder_list[0].load_state_dict(torch.load("/home/haoran/Yanzhaoshi/MiniGPT-4/minigpt4/output/minigpt4_stage2_finetune/20240919232/checkpoint_59.pth"))
def freeze_model(self):
visual_num=0
llama_num=0
for n,p in self.named_parameters():
# p.requires_grad=False
if 'visual_encoder' in n or 'llama_proj_mlps' in n or 'ln_vision' in n:
# print(n)
p.requires_grad=True
visual_num+=p.numel()
else:
llama_num+=p.numel()
p.requires_grad=False
self.grad_list.append(p)
def vit_to_cpu(self):
# self.ln_vision.to("cpu")
# self.ln_vision.float()
self.visual_encoder.to("cpu")
self.visual_encoder.float()
# Change 1: tokenizer.bos_token must come first; it is the LLaMA-3 start-of-sequence token
# Change 2: add the modality names and separator tokens to the instruction
def get_context_emb(self, prompt, img_list):
device = img_list[0].device
prompt_segs = prompt.split('<ImageHere>') # split the prompt around the image placeholders
seg_tokens = [
self.llama_tokenizer(
seg, return_tensors="pt", add_special_tokens=i==0).to(device).input_ids # only add bos to the first seg
for i, seg in enumerate(prompt_segs)
]
seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens]
# for idx, seg in enumerate(p_segs[:-1]): # iterate over all parts split by '<ImageHere>', except the last one
# # tokenize the current text segment and embed it
# p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
# p_embed = self.embed_tokens(p_tokens.input_ids)
# # add the number of tokens in this segment to the running length
# count_current_len += p_embed.size(1)
# # start position of this volume's visual tokens
# self.start_visual_token_idx.append(count_current_len)
# # insert the visual features of the current volume after this text segment
# interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * vis_chunk_size:(idx + 1) * vis_chunk_size]], dim=1))
# # add the number of tokens in the current visual chunk to the running length
# count_current_len += vis_chunk_size
# # end position of this volume's visual tokens
# self.end_visual_token_idx.append(count_current_len)
count_current_len=1
old_version=1
if old_version==0:
mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
# 2024.09.07 Yanzhaoshi: prepend the LLaMA bos token (self.llama_tokenizer.bos_token) to the start of the sequence
# llama: the first token is the bos start token
bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids)
# mixed_embs = torch.cat([bos_embed, mixed_embs], dim=1)
# self.start_visual_token_idx=[2]
# self.end_visual_token_idx=[3]
# self.llama_model.set_index(self.start_visual_token_idx,self.end_visual_token_idx)
eos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids)
mixed_embs = torch.cat([bos_embed, mixed_embs,eos_embed], dim=1)
else:
image_list=[]
cur_att=1
if cur_att==0:
for index in range(1,5):
image_list.append(img_list[0][:,32*(index-1):32*index])
mixed_embs = [emb for pair in zip(seg_embs[:-1], image_list) for emb in pair] + [seg_embs[-1]]
#######################
# mixed_embs=[]
# for pair in zip(seg_embs[:-1], image_list):
# for emb in pair:
# mixed_embs.append(emb)
mixed_embs=mixed_embs+[seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids)
mixed_embs = torch.cat([bos_embed, mixed_embs], dim=1)
# eos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids)
# mixed_embs = torch.cat([bos_embed, mixed_embs,eos_embed], dim=1)
elif cur_att==1:
image_list=[]
# record the end position of each volume's tokens, e.g. [34, 67, 100, 133]
for index in range(1,5):
image_list.append(img_list[0][:,32*(index-1):32*index])
# mixed_embs = [emb for pair in zip(seg_embs[:-1], image_list) for emb in pair] + [seg_embs[-1]]
count_current_len = 1 # running prompt length; start at 1 because the bos token is appended outside the loop. Used to locate the visual tokens
self.start_visual_token_idx = [] # record the start position of each volume's tokens, e.g. [2, 35, 68, 101]
self.end_visual_token_idx = []
mixed_embs=[]
for pair in zip(seg_embs[:-1], image_list):
count_current_len+=1
self.start_visual_token_idx.append(count_current_len)
for emb in pair:
# end position of this volume's visual tokens
mixed_embs.append(emb)
count_current_len += 32
self.end_visual_token_idx.append(count_current_len)
self.llama_model.set_index(self.start_visual_token_idx,self.end_visual_token_idx)
mixed_embs=mixed_embs+[seg_embs[-1]]
mixed_embs = torch.cat(mixed_embs, dim=1)
bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids)
eos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids)
# mixed_embs = torch.cat([bos_embed, mixed_embs], dim=1)
mixed_embs = torch.cat([bos_embed, mixed_embs,eos_embed], dim=1)
return mixed_embs
def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None): # assemble the prompts
if prompts is None or len(prompts) == 0:
# prompts is not provided, just return the original image embedding
return img_embeds, atts_img
elif img_embeds is None:
# prompt is provided but there is no image embedding. return the prompt embedding in right padding
self.llama_tokenizer.padding_side = "right"
prompt_tokens = self.llama_tokenizer(
prompts,
return_tensors="pt",
padding="longest",
add_special_tokens=False
).to(self.device)
prompt_embeds = self.embed_tokens(prompt_tokens.input_ids)
atts_prompt = prompt_tokens.attention_mask
return prompt_embeds, atts_prompt
else: # the usual multi-modal path
# return the multi-modal embedding in right padding
emb_lists = []
if isinstance(prompts, str):
prompts = [prompts] * len(img_embeds)
self.start_visual_token_idx = None
self.end_visual_token_idx = None
cond_ids_list=[]
for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)):
pn = each_img_embed.shape[-2] # total number of image tokens, 128
# pn = each_img_embed.shape[-1]
if lengths is not None:
each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1])
each_img_embed = each_img_embed[:lengths[idx] * pn]
p_segs = each_prompt.split('<ImageHere>') # split the text around the <ImageHere> placeholders so the image features can be inserted
interleave_emb = []
interleave_ids=[]
'''2024.09.08 Yanzhaoshi: split the visual features into chunks of 32: <bos> <Img> <ImageHere> <t> <ImageHere> <t> <ImageHere> <t> <ImageHere> </Img> {instruction} '''
count_current_len = 1 # running prompt length; start at 1 because the bos token is appended later. Used to locate the visual tokens
self.start_visual_token_idx = [] # record the start position of each volume's tokens, e.g. [2, 35, 68, 101]
self.end_visual_token_idx = [] # record the end position of each volume's tokens, e.g. [34, 67, 100, 133]
if '<t>' in each_prompt: # if the separator <t> is present, insert the visual features volume by volume, with each chunk separated by <t>
# 2024.09.09 compute the start and end position of each volume so the causal-mask positions in the LLM are correct
# example instruction: <Img> <ImageHere> <t> <ImageHere> <t> <ImageHere> <t> <ImageHere></Img>; the corresponding modality name is added before each <ImageHere>
vis_chunk_size = 32 # number of images per modality, treated as one chunk
for idx, seg in enumerate(p_segs[:-1]): # iterate over all parts split by '<ImageHere>', except the last one
# tokenize the current text segment and embed it
p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
#10.13 update ids
interleave_ids += p_tokens.input_ids.squeeze(0).tolist()
p_embed = self.embed_tokens(p_tokens.input_ids)
# add the number of tokens in this segment to the running length
count_current_len += p_embed.size(1)
# start position of this volume's visual tokens
self.start_visual_token_idx.append(count_current_len)
# insert the visual features of the current volume after this text segment
interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * vis_chunk_size:(idx + 1) * vis_chunk_size]], dim=1))
interleave_ids += self.llama_tokenizer('<t>', add_special_tokens=False).input_ids * (int(pn/4))
# add the number of tokens in the current visual chunk to the running length
count_current_len += vis_chunk_size
# end position of this volume's visual tokens
self.end_visual_token_idx.append(count_current_len)
else: # if no <t> separator is found, fall back to the old code and insert all visual features in one block; 2D data takes this path
for idx, seg in enumerate(p_segs[:-1]): # iterate over all parts split by '<ImageHere>', except the last one
# tokenize the current text segment and embed it
p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device)
p_embed = self.embed_tokens(p_tokens.input_ids)
# add the number of tokens in this segment to the running length
count_current_len += p_embed.size(1)
# start position of this volume's visual tokens
self.start_visual_token_idx.append(count_current_len)
# old approach: insert all image features directly
interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1))
# add the number of tokens in the current visual chunk to the running length
count_current_len += pn
# end position of this volume's visual tokens
self.end_visual_token_idx.append(count_current_len)
wrapped_emb = torch.cat(interleave_emb, dim=1) # merge the per-chunk features
# 2024.09.07 Yanzhaoshi: prepend the LLaMA bos token (self.llama_tokenizer.bos_token) to the start of the sequence
# 2024.09.23 Yanzhaoshi: removed this code because a bos token is added later, to avoid adding it twice
# bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(img_embeds.device).input_ids)
# wrapped_emb = torch.cat([bos_embed, wrapped_emb], dim=1)
p_tokens = self.llama_tokenizer(
p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device) # append the features of the final instruction segment
interleave_ids += p_tokens.input_ids.squeeze(0).tolist()
cond_ids_list.append(interleave_ids)
cond_ids_list = torch.tensor(cond_ids_list).to(img_embeds.device)
#10.13 update
p_embed = self.embed_tokens(p_tokens.input_ids)
wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1) # append the instruction text [1, 32, 4096] after the [1, 131, 4096] visual prompt
# 2024.09.11 Yanzhaoshi: append the eos token self.llama_tokenizer.eos_token
# 2024.10.13: removed here for later training runs
# eos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(img_embeds.device).input_ids)
# wrapped_emb = torch.cat([wrapped_emb, eos_embed], dim=1)
# cond_ids_list = torch.cat([cond_ids_list, torch.tensor(self.llama_tokenizer.eos_token_id).cuda().unsqueeze(0).unsqueeze(0)], dim=1)
emb_lists.append(wrapped_emb) # after merging: [1, 163, 4096]
emb_lens = [emb.shape[1] for emb in emb_lists]
pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) # embedding of the padding token, [4096]
max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len # maximum length max_length, 163
wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone() # [1, 163, 4096]
wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device) # [1, 163]
for i, emb in enumerate(emb_lists):
length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len # 163
wrapped_embs[i, :length] = emb[:, :length] # the two lengths are identical
wrapped_atts[i, :length] = 1
return wrapped_embs, wrapped_atts,cond_ids_list
def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts):
"""
Concatenate the batched input embedding and batched output embedding together.
Both the input and the output embedding should be right padded.
"""
input_lens = []
cat_embs = []
cat_atts = []
for i in range(input_embs.size(0)): # input_embs [1, 163, 4096] output_embs [1, 30, 4096]
input_len = input_atts[i].sum()
input_lens.append(input_len)
cat_embs.append(
torch.cat([
input_embs[i][:input_len],
output_embs[i],
input_embs[i][input_len:]
])
) # cat_embs [193, 4096]
cat_atts.append(
torch.cat([
input_atts[i][:input_len],
output_atts[i],
input_atts[i][input_len:]
])
) # cat_atts [193]
cat_embs = torch.stack(cat_embs) # [1, 193, 4096]
cat_atts = torch.stack(cat_atts) # [1, 193]
return cat_embs, cat_atts, input_lens
def tokenize_conversation(self, conv_q, conv_a):
"""concatenate conversation and make sure the model is only trained to regress the answer"""
to_regress_token_ids_list = []
targets_list = []
batch_size = len(conv_q)
for batch_idx in range(batch_size):
questions, answers = conv_q[batch_idx], conv_a[batch_idx]
questions = [self.llama_tokenizer(self.llama_tokenizer.bos_token + q,
return_tensors="pt",
add_special_tokens=False).to(self.device) for q in questions[1:]] # the first question is handled in the prompt wrap function, skip it
answers = [self.llama_tokenizer(a + self.end_sym,
return_tensors="pt",
add_special_tokens=False).to(self.device) for a in answers]
cur_id = []
cur_target = []
for i in range(len(questions)):
cur_id.append(answers[i].input_ids)
cur_target.append(answers[i].input_ids)
cur_id.append(questions[i].input_ids)
cur_target.append(torch.ones_like(questions[i].input_ids) * -100)
cur_id.append(answers[-1].input_ids)
cur_target.append(answers[-1].input_ids)
cur_id = torch.cat(cur_id, dim=1)
cur_target = torch.cat(cur_target, dim=1)
to_regress_token_ids_list.append(cur_id)
targets_list.append(cur_target)
max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len)
to_regress_token_ids = torch.ones([batch_size, max_len],
dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id
targets = torch.ones([batch_size, max_len],
dtype=cur_id.dtype, device=self.device) * -100
for batch_idx in range(batch_size):
cur_len = to_regress_token_ids_list[batch_idx].shape[1]
to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len]
targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len]
to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int)
return to_regress_token_ids, to_regress_token_attn, targets
def preparing_embedding(self, samples):
### prepare input tokens
if 'several_modalities_image' in samples:
# print(samples['several_modalities_image'])
images=samples['several_modalities_image']['images']
# print(images.shape)
# print(len(images))
modalities=samples['several_modalities_image']['modalities']
# print(modalities)
device=images[0].device
temp=[]
for index,image in enumerate(images):
# print(image.shape)
# img_embeds, img_atts,self_loss = self.encode_img(image,modalities[index])
# img_embeds, img_atts,self_loss=self.encode_img(image,modalities[index])
img_embeds, img_atts,self_loss=self.encode_img(image)
# print(img_embeds.shape)
temp.append(img_embeds)
img_embeds=torch.cat(temp, dim=1)
# print()
elif 'image' in samples:
if type(samples["image"])==list:
images = torch.stack(samples["image"]).to(self.device)
# print(images.shape)
else:
images=samples["image"].to(self.device)
# print(samples["image"])
# images=samples["image"].to(self.device)
img_embeds, img_atts, self_loss = self.encode_img(images, "LR_Encoding") #LR 图像的维度是[b, 4(模态数), 32(切片数), 3(初始通道数), 224(长), 224(宽)] torch.Size([1, 7, 32, 3, 224, 224])
# img_embeds [128, 1, 768] img_atts [128, 1] (全1矩阵)
if "HR_resolution" in samples.keys() and samples['HR_resolution'] ==True: #HR (0.1,0.07,nan,nan)
HR_img_embeds, HR_img_atts, HR_self_loss = self.encode_img(samples["HR_image_list"].to(self.device), "HR_Encoding") #HR [224*196*4, 1, 768]
#
# print(HR_img_embeds.shape,img_embeds.shape)
img_embeds=self.cross_attention(img_embeds.reshape(1,-1,768), HR_img_embeds.reshape(1,-1,768)) # 1, 128, 768
# print('----------------')
# print(img_embeds.shape)
# 修改1 在此处加上projector,映射到LLM维度4096
img_embeds=self.llama_proj_mlps(img_embeds) # 1, 128, 4096
device=images.device
else:
img_embeds = img_atts = None
# print(img_embeds.shape)
if 'conv_q' in samples:
# handling conversation datasets
conv_q, conv_a = samples['conv_q'], samples['conv_a']
connect_sym = samples['connect_sym'][0]
conv_q = [q.split(connect_sym)for q in conv_q]
conv_a = [a.split(connect_sym) for a in conv_a]
conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q]
cond_embeds, cond_atts,_ = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q])
regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a)
else:
if "instruction_input" in samples:
instruction = samples["instruction_input"] # ["<Img><ImageHere></Img> [caption] There are several MRI sequence from 1 patient. Please make a detailed diagnosis for this patient including the tumor's detailed type. "]
elif self.prompt_list:
instruction = random.choice(self.prompt_list)
else:
instruction = None
if hasattr(self, 'chat_template') and self.chat_template:
instruction = [self.prompt_template.format(instruct) for instruct in instruction]
if 'length' in samples:
# the input is an image sequence (like a video)
bsz, pn, hs = img_embeds.shape
# print(img_embeds.shape)
img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs)
cond_embeds, cond_atts,_ = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length'])
else:
# return self_loss
# print(img_embeds.shape)#(a,32,4096)
# img_embeds=self.cross_attention(img_embeds)
# print('img_embeds_0',img_embeds.shape)
# img_embeds=img_embeds.reshape(1,-1,*img_embeds.shape[-1:])#32,32,
# print('img_embeds_1',img_embeds.shape)
if img_embeds.size(0)>8 and self.has_qformer==True:
img_embeds=img_embeds.reshape(img_embeds.size(0),img_embeds.size(1),64,64)
# img_embeds=self.spatial_pooling_layer_2d(img_embeds)
# print('img_embeds',img_embeds.shape)
img_embeds=img_embeds.permute(1,0,2,3)
img_embeds=self.conv3d_layer(img_embeds)
img_embeds=img_embeds.reshape(img_embeds.size(0),img_embeds.size(1),4096)
img_embeds=self.conv3d_proj(img_embeds).reshape(1,-1,4096)
else: # the usual path
img_embeds=img_embeds.reshape(1,-1,4096) # many images fed in at once, with a single instruction
# print('img_embeds',img_embeds.shape)
# print('after downsampling',img_embeds.shape)
# print(img_embeds.shape)#(1,a*32,4096)
# assemble the multi-modal prompt
cond_embeds, cond_atts,_ = self.prompt_wrap(img_embeds, img_atts, instruction) # cond_embeds [1, 163, 4096] cond_atts [1, 163]
### prepare target tokens
self.llama_tokenizer.padding_side = "right"
text = [t + self.end_sym for t in samples["answer"]] # target (self.end_sym = ###) 'This patient was diagnosed with Brain Metastases Tumor, and further diagnosed as Brain Metastases Tumor###'
# print(f"text {text}")
regress_tokens = self.llama_tokenizer(
text,
return_tensors="pt",
padding="longest",
truncation=True,
max_length=self.max_txt_len,
add_special_tokens=False
).to(self.device) # regress is the target (tokenized GT report), torch.Size([1, 30])
# print(regress_tokens.shape)
regress_token_ids = regress_tokens.input_ids # [1, 30]
# print("regress_token_ids:",regress_token_ids)
regress_atts = regress_tokens.attention_mask # [1, 30]
part_targets = regress_token_ids.masked_fill(
regress_token_ids == self.llama_tokenizer.pad_token_id, -100
) # ignore the padded tokens, [1, 30]
# print(part_targets)
# print(part_targets.shape)
if 'class_target' in samples:
class_target=samples['class_target']
class_target=class_target.to(device)
else:
class_target=None
regress_embeds = self.embed_tokens(regress_token_ids) # gt report embedding [1, 30, 4096]
return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets, img_embeds, class_target, self_loss
def forward(self, samples, reduction='mean'):
# prepare the embedding to condition and the embedding to regress
# not_load=True
if self.self_training == False:
cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets,input_embs,class_target,self_loss = \
self.preparing_embedding(samples) #
else:
self.visual_encoder=self.visual_encoder_list[0]
image=samples['image'].to(self.device)
if len(image.shape) > 4:
image = image.reshape(-1, *image.shape[-3:])
with self.maybe_autocast():
loss,_,_,latent=self.visual_encoder(image)
# self_loss=self.encode_img(samples)
return {"loss": loss}
# print(samples['bounding_box'])
# if 'bounding_box' in samples and samples['bounding_box']!=['None']:
# label=samples['bounding_box']
# # bounding_box_loss_fn=iou_loss()
# # x=input_embs
# # x=self.decoder_linear(x)
# # for blk in self.decoder_blocks:
# # x = blk(x)
# # x = self.decoder_norm(x)
# # # predictor projection
# # x = self.decoder_pred(x.view(1,-1))
# # # print(input_embs.view(1,-1).shape)
# bounding_box=self.boundinb_box_relu(self.bounding_box_layer(input_embs.view(1,-1)))
# loss_fn=nn.SmoothL1Loss()
# # print('bounding box',bounding_box[0])
# # print('label',label[0])
# # print('bounding box',bounding_box)
# # print('label',label)
# iou_loss=calculate_iou(label[0],bounding_box[0])
# giou_loss=loss_fn(label,bounding_box)
# bounding_box_loss=giou_loss
# else:
# iou_loss=1
# bounding_box_loss=None
# print(cond_embeds.shape)
# concat the embedding to condition and the embedding to regress
# print('class target is ',class_target)
if class_target!=None:
img_num=input_embs.shape[1]/32
# class_predict=self.class_layer_list[int(img_num)-1](input_embs.view(input_embs.size(0), -1))
classification_loss_fn=torch.nn.CrossEntropyLoss()
# print('predict shape is',class_predict.shape)
# print('target shape is ',class_target.shape)
class_loss=classification_loss_fn(class_predict,class_target)
else:
class_loss=None
inputs_embeds, attention_mask, input_lens = \
self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts) # 把instruction和target concat
# get bos token embedding
bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
bos_embeds = self.embed_tokens(bos) # [1, 1, 4096]
bos_atts = cond_atts[:, :1] # [1, 1], all ones
# # add bos token at the begining
# print(bos_embeds.shape,inputs_embeds.shape)
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) # [1, 173, 4096]
attention_mask = torch.cat([bos_atts, attention_mask], dim=1) # [1, 173], all ones
# ensemble the final targets
targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
dtype=torch.long).to(self.device).fill_(-100) #
for i, target in enumerate(part_targets): # the preceding token ids are set to -100; the gt token indices that follow stay aligned
targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target # plus 1 for bos
# print(targets)
# print(inputs_embeds.shape)
# mask=samples['mask']
# if bounding_box_loss==Tr:
# print('self_loss',self_loss)
# print(targets.shape)
if self_loss==None:
with self.maybe_autocast(): # feed the multimodal prompt into the LLM
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
start_visual_idx = self.start_visual_token_idx,
end_visual_idx = self.end_visual_token_idx,
return_dict=True,
labels=targets,
reduction=reduction
)
hidden_states=outputs.hidden_states[-1] # outputs.hidden_states [1, 173, 4096] hidden_states [173, 4096]
# class_predict=self.class_layer(hidden_states[:,:32,:].view(1,-1))
# print(outputs.keys())
# print(hidden_states.shape)
# if self_loss:
# loss=self_loss
if "bounding_box" in samples:
gt_bounding_box=samples["bounding_box"]
# print(self.bounding_box_label)
# print(targets.shape)
# print(self.location_label)
output_token_index = (targets == self.bounding_box_label).nonzero()
# print(targets)
if len(output_token_index)>0:
location_state=hidden_states[output_token_index[0][-1]]
bounding_box=self.seg_layer(location_state.view(1, -1))
# print(gt_bounding_box[0])
# print('predict',bounding_box)
# print('gt',gt_bounding_box)
# print(gt_bounding_box)
L1_loss=self.L1_loss_fn(bounding_box,gt_bounding_box)
# print(bounding_box,gt_bounding_box)
giou_loss=bbox_giou_loss(bounding_box,gt_bounding_box)
# print(1-giou_loss)
if giou_loss>0:
# print(L1_loss)
loss=outputs.loss+0.2*L1_loss+1.2*giou_loss
else:
loss=0.8*outputs.loss+0.2*L1_loss
else:
loss=outputs.loss
# print(giou_loss)
# print(output_token_index)
# if len(output_token_index):
# addon_index = torch.ones_like(output_token_index)*(-1)
# addon_index[:, 0] = 0
# output_token_index += addon_index
# print(output_token_index)
# class_predict
# loss = class_loss
# elif bounding_box_loss:
# loss=bounding_box_loss
else:
loss=outputs.loss
# print("loss is ",loss)
# loss=outputs.loss
# if self_loss:
# # print('sssssssssssss')
# loss=loss+0.25*self_loss
if torch.isnan(loss):
print(samples['image_name'])
return {"loss": loss}
def embed_tokens(self, token_ids):
if hasattr(self.llama_model.base_model, 'model'): ## lora wrapped model
embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids)
else:
embeds = self.llama_model.base_model.embed_tokens(token_ids)
return embeds
@torch.no_grad()
def multi_generate(
self,
samples,
texts,
num_beams=1,
max_new_tokens=160,
min_length=1,
top_p=0.9,
repetition_penalty=1,
length_penalty=1,
temperature=1,
do_sample=False,
stop_words_ids=[2],
):
'''
function for generate test use
'''
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
# images = torch.stack(images, dim=1)
# img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device))
if 'several_modalities_image' in samples:
images=samples['several_modalities_image']['images']
# print(len(images))
modalities=samples['several_modalities_image']['modalities']
# print(modalities)
# device=images[0].device
temp=[]
for index,image in enumerate(images):
# print(image.shape)
# img_embeds, img_atts,self_loss = self.encode_img(image,modalities[index])
img_embeds, img_atts,self_loss = self.encode_img(image.to(self.device).half(),modalities[index])
# print(img_embeds.shape)
temp.append(img_embeds)
img_embeds=torch.cat(temp, dim=1)
image_lists = [[image_emb[None]] for image_emb in img_embeds]
# print(len(image_lists),len(texts))
# for i in image_lists:
# print(i[0].shape)
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
batch_size = len(batch_embs)
max_len = max([emb.shape[1] for emb in batch_embs])
emb_dim = batch_embs[0].shape[2]
dtype = batch_embs[0].dtype
device = batch_embs[0].device
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
for i, emb in enumerate(batch_embs):
emb_len = emb.shape[1]
embs[i, -emb_len:] = emb[0]
attn_mask[i, -emb_len:] = 1
with self.maybe_autocast():
outputs = self.llama_model.generate(
inputs_embeds=embs,
attention_mask=attn_mask,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
length_penalty=length_penalty,
temperature=temperature,
do_sample=do_sample,
min_length=min_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
stopping_criteria=stopping_criteria,
)
answers = []
for output_token in outputs:
if output_token[0] == 0:
output_token = output_token[1:]
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
output_texts = output_texts.replace("<s>", "")
output_texts = output_texts.split(r'[/INST]')[-1].strip()
answers.append(output_texts)
return answers
@torch.no_grad()
def generate(
self,
images,
texts,
num_beams=1,
max_new_tokens=300,
min_length=1,
top_p=0.9,
repetition_penalty=1,
length_penalty=1,
temperature=1, # can be set to 0.1, 0.2, 0.4, 0.6, 0.8, or 1
do_sample=False,
stop_words_ids=[2],
):
'''
function for generate test use
'''
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
# images = torch.stack(images, dim=1)
img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device))
image_lists = [[image_emb[None]] for image_emb in img_embeds]
# print(len(image_lists),len(texts))
# for i in image_lists:
# print(i[0].shape)
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
batch_size = len(batch_embs)
max_len = max([emb.shape[1] for emb in batch_embs])
emb_dim = batch_embs[0].shape[2]
dtype = batch_embs[0].dtype
device = batch_embs[0].device
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
for i, emb in enumerate(batch_embs):
emb_len = emb.shape[1]
embs[i, -emb_len:] = emb[0]
attn_mask[i, -emb_len:] = 1
with self.maybe_autocast():
outputs = self.llama_model.generate(
inputs_embeds=embs,
attention_mask=attn_mask,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
length_penalty=length_penalty,
temperature=temperature,
do_sample=do_sample,
min_length=min_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
stopping_criteria=stopping_criteria,
)
answers = []
for output_token in outputs:
if output_token[0] == 0:
output_token = output_token[1:]
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
output_texts = output_texts.replace("<s>", "")
output_texts = output_texts.split(r'[/INST]')[-1].strip()
answers.append(output_texts)
return answers
# def single_predict_class(self,samples,device):
# images=samples['image']
# img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device))
# class_predict=self.class_layer_list[0](img_embeds.view(1, -1).float())
# return class_predict
def predict_class(self,samples,device):
# img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device))
# print(type(img_embeds))
# print(img_embeds.shape)
images=samples['several_modalities_image']['images']
# print(len(images))
modalities=samples['several_modalities_image']['modalities']
# print(modalities)
# device=images[0].device
temp=[]
num=len(images)
for index,image in enumerate(images):
# print(image.shape)
# img_embeds, img_atts,self_loss = self.encode_img(image,modalities[index])
img_embeds, img_atts,self_loss=self.encode_img(image.to(device).half(),modalities[index])
# print(img_embeds.shape)
temp.append(img_embeds)
img_embeds=torch.cat(temp, dim=1)
# class_predict=self.class_layer_list[num-1](img_embeds.view(1, -1).float())
# if class_target!=None:
# classification_loss_fn=torch.nn.CrossEntropyLoss()
# class_loss=classification_loss_fn(class_predict,class_target)
return class_predict
@torch.no_grad()
def predict_boundingbox(
self,
images,
texts,
HR_images=None,
num_beams=1,
max_new_tokens=300,
min_length=1,
top_p=0.9,
repetition_penalty=1,
length_penalty=1,
temperature=1,
do_sample=False,
stop_words_ids=[2],
output_hidden_states=True
):
'''
function for generate test use
'''
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
# images = torch.stack(images, dim=1)
img_embeds, atts_img, self_loss = self.encode_img(images.to(self.device), "LR_Encoding")
if HR_images!=[]:
# print('-----------')
HR_embeds,HR_atts_img,HR_loss = self.encode_img(HR_images.to(self.device), "HR_Encoding")
img_embeds = self.cross_attention(img_embeds.reshape(1,-1,768).float(),HR_embeds.reshape(1,-1,768).float())
# Change 1: apply the projector here to map to the LLM dimension 4096
img_embeds=self.llama_proj_mlps(img_embeds) # 1, 128, 4096
img_embeds=img_embeds.reshape(1,-1,*img_embeds.shape[-1:])
if img_embeds.size(0)>8:
img_embeds=img_embeds.reshape(1,-1,4096)
# print(img_embeds.shape)
# img_embeds=self.conv3d_proj(img_embeds)
else:
img_embeds=img_embeds.reshape(1,-1,4096)
# print(img_embeds.shape)
image_lists = [[image_emb[None]] for image_emb in img_embeds]
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
batch_size = len(batch_embs)
max_len = max([emb.shape[1] for emb in batch_embs])
emb_dim = batch_embs[0].shape[2]
dtype = batch_embs[0].dtype
device = batch_embs[0].device
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
for i, emb in enumerate(batch_embs):
emb_len = emb.shape[1]
embs[i, -emb_len:] = emb[0]
attn_mask[i, -emb_len:] = 1
# with self.maybe_autocast():
# outputs = self.llama_model.generate(
# inputs_embeds=embs,
# attention_mask=attn_mask,
# max_new_tokens=max_new_tokens,
# num_beams=num_beams,
# length_penalty=length_penalty,
# temperature=temperature,
# do_sample=do_sample,
# min_length=min_length,
# top_p=top_p,
# repetition_penalty=repetition_penalty,
# stopping_criteria=stopping_criteria,
# )
with self.maybe_autocast():
outputs = self.llama_model.generate(
inputs_embeds=embs,
attention_mask=attn_mask,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
length_penalty=length_penalty,
temperature=temperature,
do_sample=do_sample,
min_length=min_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
stopping_criteria=stopping_criteria,
return_dict_in_generate=True,
output_hidden_states=output_hidden_states,
pad_token_id=2,
eos_token_id=2
)
hidden_states=outputs['hidden_states']
# [39,38,66177,1163,44,60158,40249,47,51,34,39]
answers = []
top_2_list=self.llama_model.get_top()
if top_2_list[0]!=None and top_2_list[1]!=None:
top2_answer=self.llama_tokenizer.decode(top_2_list, skip_special_tokens=True)
# print(top2_answer)
self.llama_model.delete_max_id()
for output_token in outputs['sequences']:
if output_token[0] == 0:
output_token = output_token[1:]
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
'''
Printed output_texts once; why is it so long? Shouldn't the model emit the end token right after the answer and stop, so that everything afterwards is end tokens that get skipped during decoding?
'This patient was diagnosed with Embryonal tumours### Short diagnosis: Embryonal tumours### Diagnosis: MedulloblastomaThis patient was diagnosed
with Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma###
Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal
tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis:
Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma###
Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal
tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis:
Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma###
Diagnosis: Embryonal tum'
The next few lines look like end-of-sequence handling; was the end token actually added during training?
'''
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
output_texts = output_texts.replace("<s>", "")
output_texts = output_texts.split(r'[/INST]')[-1].strip()
answers.append(output_texts)
# print(output_texts)
# print('answers',answers)
return answers
def get_input_embeddings(self):
return self.llama_model.get_input_embeddings()
@torch.no_grad()
def generate_step(
self,
images,
texts,
HR_images=None,
num_beams=1,
echo=False,
max_new_tokens=500,
min_length=1,
top_p=0.9,
repetition_penalty=1,
length_penalty=1,
temperature=-1, # -1 0.01 0.05 0.1 0.2 0.4 0.6 0.8 1
do_sample=False,
stop_words_ids=[2],
output_hidden_states=True
):
'''
function for generate test use
'''
# breakpoint()
# self.llama_model.to(self.device)
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
# images = torch.stack(images, dim=1)
# breakpoint()
img_embeds, atts_img, self_loss = self.encode_img(images.to(self.device), "LR_Encoding")
if HR_images!=[]:
# print('-----------')
HR_embeds,HR_atts_img,HR_loss = self.encode_img(HR_images.to(self.device), "HR_Encoding")
img_embeds = self.cross_attention(img_embeds.reshape(1,-1,768).float(),HR_embeds.reshape(1,-1,768).float())
# Change 1: apply the projector here to map to the LLM dimension 4096
img_embeds=self.llama_proj_mlps(img_embeds) # 1, 128, 4096
bsz=img_embeds.shape[0]
img_embeds=img_embeds.reshape(1,-1,*img_embeds.shape[-1:])
# print(f'img_embeds{img_embeds.shape}')
if img_embeds.size(0)>8:
img_embeds=img_embeds.reshape(1,-1,4096)
# print(img_embeds.shape)
# img_embeds=self.conv3d_proj(img_embeds)
else:
img_embeds=img_embeds.reshape(1,-1,4096)
# print(img_embeds.shape)
# image_lists = [[image_emb[None]] for image_emb in img_embeds]
# batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
inputs_embeds,attention_mask,input_ids=self.prompt_wrap(img_embeds,atts_img,texts)
# get bos token embedding
bos = torch.ones_like(attention_mask[:, :1]) * self.llama_tokenizer.bos_token_id
bos_embeds = self.embed_tokens(bos) # [1, 1, 4096]
bos_atts = attention_mask[:, :1] # [1, 1], all ones
# add bos token at the beginning
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) # [1, 173, 4096]
# print(f'inputs_embeds{inputs_embeds.shape}')
attention_mask = torch.cat([bos_atts, attention_mask], dim=1) # [1, 173], all ones
input_ids = torch.cat([bos, input_ids], dim=1)
# bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(batch_embs.device).input_ids)
# batch_embs = torch.cat([bos_embed, batch_embs], dim=1)
# batch_embs=[batch_embs]
batch_size = len(inputs_embeds)
# min_len = min([emb.shape[1] for emb in inputs_embeds])
# max_len = max([emb.shape[1] for emb in inputs_embeds])
emb_dim = inputs_embeds.shape[2]
dtype = inputs_embeds.dtype
device = inputs_embeds.device
min_prompt_len = min(len(t) for t in inputs_embeds) # 84
max_prompt_len = max(len(t) for t in inputs_embeds) # 84
total_len = max_new_tokens + min_prompt_len
#############
# embeds = torch.full((bsz, total_len, inputs_embeds.shape[-1]), 0, dtype=torch.float, device="cuda") # 30 339 4096
# for k, t in enumerate(inputs_embeds):
# embeds[k, : len(t)] = t
# for k, t in enumerate(input_ids):
# tokens[k, : len(t)] = t
###########
self.pad_token_id=self.llama_tokenizer.pad_token_id
pad_id = self.pad_token_id
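# Pre-allocate token and embedding buffers for the full generation length (prompt + max_new_tokens); prompt positions are filled below and generated tokens are written in place inside the decoding loop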
tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.int, device=device)
embs = torch.zeros([batch_size, total_len, emb_dim], dtype=dtype, device=device)
for k, t in enumerate(inputs_embeds):
embs[k, : len(t)] = t
for k, t in enumerate(input_ids):
tokens[k, : len(t)] = t
prev_pos=0
eos_reached = torch.tensor([False] * bsz, device=self.device)
input_text_mask = tokens != pad_id
self.stop_token_id = int(self.llama_tokenizer.encode("###")[-1])
stop_tokens = torch.tensor([self.stop_token_id], device=self.device)
# attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
# for i, emb in enumerate(batch_embs):
# emb_len = emb.shape[1]
# embs[i, -emb_len:] = emb[0]
# attn_mask[i, -emb_len:] = 1
# batch_embs=batch_embs[0]
# eos_token=self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(device).input_ids[0]
# prev_pos=0
past_key_values=None
# self.stop_token_id = int(self.llama_tokenizer.encode("###")[-1])
# stop_tokens = torch.tensor([self.stop_token_id], device="cuda")
# eos_reached = torch.tensor([False] * bsz, device="cuda")
with_probs = False
calculate_prob=False
# breakpoint()
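# Manual incremental decoding with a KV cache: each iteration feeds only the new positions, picks the next token greedily (temperature <= 0) or via top-p sampling, and stops once every sequence has emitted the '###' stop token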
for cur_pos in range(min_prompt_len, total_len):
with self.maybe_autocast():
outputs = self.llama_model(
# attention_mask=attention_mask[:, prev_pos:cur_pos],
# position_ids=position_ids,
past_key_values=past_key_values,
inputs_embeds=embs[:, prev_pos:cur_pos],
use_cache=True,
output_hidden_states=output_hidden_states,
return_dict=True,
start_visual_idx = self.start_visual_token_idx,
end_visual_idx = self.end_visual_token_idx,
)
logits = outputs["logits"] # 1 176 128256
past_key_values = outputs["past_key_values"]
if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)
next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
import copy
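# Optional diagnostic (off by default): read out softmax probabilities over a fixed list of token ids, assumed here to be the tokens for class numbers 0-13 under this tokenizer, and print the top-2 candidate classes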
if calculate_prob:
probs = torch.softmax(logits[:, -1], dim=-1)
index_list=[15,16,17,18,19,20,21,22,23,24,605,806,717,1032]
new_logits=copy.deepcopy(logits[0][0])
tumor_list=[new_logits[k] for k in index_list]
tumor_list = torch.stack(tumor_list)
probabilities = torch.softmax(tumor_list, dim=0)
token_dict={probabilities[k]:index_list[k] for k in range(len(index_list))}
top_2=sorted(token_dict.keys(), reverse=True)[:2]
# print(top_2[0].item(),top_2[1].item())
top1=token_dict[top_2[0]]
top2=token_dict[top_2[1]]
print(f'<class_{self.llama_tokenizer.decode(top1)}>',top_2[0].item(),f'<class_{self.llama_tokenizer.decode(top2)}>',top_2[1].item())
with_probs=False
calculate_prob=False
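# Arm the diagnostic when the '<class_...>' pattern starts: token id 449 (assumed to open the tag) sets with_probs, and a following id 62 (assumed to be the '_' piece) triggers the probability readout on the next step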
if next_token==449:
with_probs=True
if with_probs:
if next_token==62:
calculate_prob=True
tokens[:, cur_pos] = next_token
next_token_embed = self.llama_model.get_input_embeddings()(next_token)
embs[:, cur_pos] = next_token_embed
# breakpoint()
eos_reached |= (~input_text_mask[:, cur_pos]) & (
torch.isin(next_token, stop_tokens.to(self.device))
)
prev_pos = cur_pos
if all(eos_reached):
break
# out_tokens, out_logprobs = [], []
##############################
# for i, toks in enumerate(tokens.tolist()):
# # cut to max gen len
# start = 0 if echo else len(input_ids[i])
# toks = toks[start: len(input_ids[i]) + max_gen_len]
# probs = None
# # if logprobs:
# # probs = token_logprobs[i][start : len(input_ids[i]) + max_gen_len]
# # cut to after eos tok if any
# for stop_token in [self.stop_token_id]:
# try:
# eos_idx = toks.index(stop_token)
# toks = toks[:eos_idx]
# # probs = probs[:eos_idx] if logprobs else None
# except ValueError:
# pass
# out_tokens.append(toks)
# # out_logprobs.append(probs)
# # return (out_tokens, out_logprobs if logprobs else None)
# return out_tokens
# hidden_states=outputs['hidden_states']
# # [39,38,66177,1163,44,60158,40249,47,51,34,39]
# answers = []
# # top_2_list=self.llama_model.get_top()
# # if top_2_list[0]!=None and top_2_list[1]!=None:
# # top2_answer=self.llama_tokenizer.decode(top_2_list, skip_special_tokens=True)
# # # print(top2_answer)
# # self.llama_model.delete_max_id()
# for output_token in tokens:
# if output_token[0] == 0:
# output_token = output_token[1:]
# output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
# output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
# output_texts = output_texts.replace("<s>", "")
# output_texts = output_texts.split(r'[/INST]')[-1].strip()
# answers.append(output_texts)
# # print(output_texts)
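# Post-process: drop the prompt (unless echo=True), truncate at the first '###' stop token, and decode; only the first sequence in the batch is decoded into the returned answer string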
out_tokens, out_logprobs = [], []
for i, toks in enumerate(tokens.tolist()):
# cut to max gen len
start = 0 if echo else len(input_ids[i])
toks = toks[start: len(input_ids[i]) + max_new_tokens]
probs = None
# if logprobs:
# probs = token_logprobs[i][start : len(input_ids[i]) + max_gen_len]
# cut to after eos tok if any
for stop_token in [self.stop_token_id]:
try:
eos_idx = toks.index(stop_token)
toks = toks[:eos_idx]
# probs = probs[:eos_idx] if logprobs else None
except ValueError:
pass
out_tokens.append(toks)
answers = self.llama_tokenizer.decode(out_tokens[0])
return answers
def generate_segmentation(
self,
images,
texts,
num_beams=1,
max_new_tokens=300,
min_length=1,
top_p=0.9,
repetition_penalty=1,
length_penalty=1,
temperature=1,
do_sample=False,
stop_words_ids=[2],
output_hidden_states=True
):
'''
Generation function for segmentation tests: decodes text and, when the bounding-box token is produced, regresses box coordinates from its hidden state.
'''
stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(
stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])])
# images = torch.stack(images, dim=1)
img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device))
image_lists = [[image_emb[None]] for image_emb in img_embeds]
# print(len(image_lists),len(texts))
# for i in image_lists:
# print(i[0].shape)
batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)]
batch_size = len(batch_embs)
max_len = max([emb.shape[1] for emb in batch_embs])
emb_dim = batch_embs[0].shape[2]
dtype = batch_embs[0].dtype
device = batch_embs[0].device
embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device)
attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device)
for i, emb in enumerate(batch_embs):
emb_len = emb.shape[1]
embs[i, -emb_len:] = emb[0]
attn_mask[i, -emb_len:] = 1
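# Left-pad: each prompt embedding is written at the right end of the buffer so generation continues from the last real token, with zeros (and mask 0) on the left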
with self.maybe_autocast():
outputs = self.llama_model.generate(
inputs_embeds=embs,
attention_mask=attn_mask,
max_new_tokens=max_new_tokens,
num_beams=num_beams,
length_penalty=length_penalty,
temperature=temperature,
do_sample=do_sample,
min_length=min_length,
top_p=top_p,
repetition_penalty=repetition_penalty,
stopping_criteria=stopping_criteria,
return_dict_in_generate=True,
output_hidden_states=output_hidden_states
)
hidden_states=outputs['hidden_states']
# print(hidden_states[0].shape)
# print(outputs.keys())
location=(outputs['sequences'] == self.bounding_box_label).nonzero().flatten()
# print(self.bounding_box_label)
# print(len(outputs))
# print('location',location)
# print(len(hidden_states))
# print(hidden_states.shape)
predict_boundingbox=[]
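# If the bounding-box label token appears in the generated sequence, take the hidden state at its last occurrence and regress box coordinates through seg_layer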
if location.numel() != 0:
location=location[-1]
# print(outputs.shape)
# print(location)
# print(hidden_states[0].shape)
location_state=hidden_states[location]
# print(location_state.view(1,-1).shape)
predict_boundingbox=self.seg_layer(location_state.view(1,-1).float())
# # print(1111111111111111111)
# print(predict_boundingbox.cpu())
answers = []
for output_token in outputs['sequences']:
if output_token[0] == 0:
output_token = output_token[1:]
output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True)
output_texts = output_texts.split('</s>')[0] # remove the stop sign </s>
output_texts = output_texts.replace("<s>", "")
output_texts = output_texts.split(r'[/INST]')[-1].strip()
answers.append(output_texts)
# print(output_texts)
if predict_boundingbox!=[]:
return answers,predict_boundingbox
else:
return answers,[]
# print('answers',answers)
# return answers,predict_boundingbox
def predict_class_(self,images,reduction='mean'):
# img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device))
# print(type(img_embeds))
# print(img_embeds.shape)
cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets,input_embs,class_target,self_loss = \
self.preparing_embedding(images)
inputs_embeds, attention_mask, input_lens = \
self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts)
# get bos token embedding
bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id
bos_embeds = self.embed_tokens(bos)
bos_atts = cond_atts[:, :1]
# add bos token at the beginning
inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1)
attention_mask = torch.cat([bos_atts, attention_mask], dim=1)
# assemble the final targets (all positions ignored: filled with -100)
targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]],
dtype=torch.long).to(self.device).fill_(-100)
with self.maybe_autocast():
outputs = self.llama_model(
inputs_embeds=inputs_embeds,
attention_mask=attention_mask,
return_dict=True,
labels=None,
reduction=reduction
)
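# Class prediction: flatten the hidden states of the first 32 positions (assumed to correspond to the visual tokens) and pass them through the classification head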
hidden_states=outputs.hidden_states
class_predict=self.class_layer(hidden_states[:,:32,:].view(1,-1).float())
# class_predict=self.class_layer(img_embeds.view(1, -1).float())
# if class_target!=None:
# classification_loss_fn=torch.nn.CrossEntropyLoss()
# class_loss=classification_loss_fn(class_predict,class_target)
return class_predict
def predict_boundingbox_(self,samples):
if 'several_modalities_image' in samples:
images=samples['several_modalities_image']['images']
modalities=samples['several_modalities_image']['modalities']
# print(modalities)
device=images[0].device
temp=[]
for index,image in enumerate(images):
# print(image.shape)
img_embeds, img_atts,self_loss = self.encode_img(image.to(self.device),modalities[index])
temp.append(img_embeds)
img_embeds=torch.cat(temp, dim=1)
# img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device))
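# Regress the bounding box directly from visual features: project the concatenated image embeddings, run them through the decoder blocks, normalize, and map the flattened result through the prediction head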
x=img_embeds.float()
x=self.decoder_linear(x)
for blk in self.decoder_blocks:
x = blk(x)
x = self.decoder_norm(x)
# predictor projection
x = self.decoder_pred(x.view(1,-1))
# print(x.shape)
return x
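# multi_select ranks candidate answers by their per-sample language-modeling loss (lower loss = better candidate); num_cand, when given, masks out unused candidate slots per sample before ranking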
@torch.no_grad()
def multi_select(self, images, texts, answers, num_cand=None):
all_losses = []
for answer in answers:
choice_samples = {
'image': images,
'instruction_input': texts,
'answer': answer
}
loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1)
all_losses.append(loss)
torch.cuda.empty_cache()
all_losses = torch.cat(all_losses, dim=-1)
if num_cand is not None:
for i in range(all_losses.shape[0]):
all_losses[i, num_cand[i]:] = 9999
output_class_ranks = torch.argsort(all_losses, dim=-1)
return output_class_ranks.tolist()
def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.
Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.
Returns:
torch.Tensor: Sampled token indices.
Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
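# Illustrative usage (commented out, not executed at import time); the values are made up for the example:
# probs = torch.softmax(torch.tensor([[2.0, 1.0, 0.1]]), dim=-1)
# next_token = sample_top_p(probs, p=0.9) # shape [1, 1]: index of the sampled token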