import logging
import math
import random

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.cuda.amp import autocast
from einops import rearrange
from einops.layers.torch import Rearrange

import open_clip
from open_clip.factory import HF_HUB_PREFIX, _MODEL_CONFIGS
from transformers import StoppingCriteria, StoppingCriteriaList, AutoTokenizer, AutoModel

from minigpt4.common.registry import registry
from minigpt4.models.base_model import BaseModel, LayerNorm
from minigpt4.conversation.conversation import StoppingCriteriaSub
from minigpt4.models.eva_vit import Block
from minigpt4.models.mae_vit import mae_vit_large_patch16
|
|
| class SpatialPoolingProjector(nn.Module): |
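    """Average-pool ViT patch tokens (over the 3D patch grid or the token sequence)
    and project them with a stack of linear or MLP layers."""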
| def __init__(self, image_size=(32, 256, 256), patch_size=(8, 16, 16), in_dim=768, out_dim=768, layer_type='linear', layer_num=6, pooling_type='spatial', pooling_size=2): |
| |
| super().__init__() |
| self.in_dim = in_dim |
| self.pooling_size = pooling_size |
|
|
| self.num_patches_pre = [img // pch for img, pch in zip(image_size, patch_size)] |
| |
| self.num_patches_post = [num // pooling_size for num in self.num_patches_pre] |
| |
| if layer_type == 'linear': |
| depth = int(layer_num) |
| modules = [nn.Linear(in_dim, out_dim)] |
| for _ in range(1, depth): |
| modules.append(nn.Linear(out_dim, out_dim)) |
| self.projector = nn.Sequential(*modules) |
| elif layer_type == 'mlp': |
| depth = int(layer_num) |
| modules = [nn.Linear(in_dim, out_dim)] |
| for _ in range(1, depth): |
| modules.append(nn.GELU()) |
| modules.append(nn.Linear(out_dim, out_dim)) |
| self.projector = nn.Sequential(*modules) |
        else:
            raise ValueError(f"Unsupported projector layer_type: {layer_type}")
|
|
| self.pooling_type = pooling_type |
|
|
| def forward(self, x): |
| B = x.shape[0] |
| |
| if self.pooling_type == 'spatial': |
| to_3d = Rearrange("b (p1 p2 p3) d -> b d p1 p2 p3", b=B, d=self.in_dim, p1=self.num_patches_pre[0], p2=self.num_patches_pre[1], p3=self.num_patches_pre[2]) |
| x = to_3d(x) |
| x = F.avg_pool3d(x, kernel_size=self.pooling_size, stride=self.pooling_size) |
| to_seq = Rearrange("b d p1 p2 p3 -> b (p1 p2 p3) d", b=B, d=self.in_dim, p1=self.num_patches_post[0], p2=self.num_patches_post[1], p3=self.num_patches_post[2]) |
| x = to_seq(x) |
| elif self.pooling_type == 'sequence': |
| x = x.permute(0, 2, 1) |
| x = F.avg_pool1d(x, kernel_size=self.pooling_size**3, stride=self.pooling_size**3) |
| x = x.permute(0, 2, 1) |
|
|
| x = rearrange(x, "b n d -> (b n) d") |
| |
| x = self.projector(x) |
| x = rearrange(x, "(b n) d -> b n d", b=B) |
|
|
| return x |
|
|
| @property |
| def proj_out_num(self): |
| num = 1 |
| for n in self.num_patches_post: |
| num *= n |
| return num |
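
# Example usage of SpatialPoolingProjector (a minimal sketch; tensors and shapes are
# illustrative assumptions): with the defaults, a 32x256x256 volume split into 8x16x16
# patches gives 4*16*16 = 1024 tokens, pooled by 2 along each axis down to 2*8*8 = 128.
#   proj = SpatialPoolingProjector()
#   tokens = torch.randn(2, 4 * 16 * 16, 768)   # (B, num_patches_pre, in_dim)
#   out = proj(tokens)                          # (B, proj.proj_out_num, out_dim) == (2, 128, 768)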
|
|
|
|
|
|
def calculate_iou(box1, box2):
    """IoU of two boxes given as [ymax, xmax, ymin, xmin] (pixel-inclusive convention)."""
    y1max, x1max, y1min, x1min = box1[0], box1[1], box1[2], box1[3]
    y2max, x2max, y2min, x2min = box2[0], box2[1], box2[2], box2[3]
| |
| s1 = (y1max - y1min + 1.) * (x1max - x1min + 1.) |
| s2 = (y2max - y2min + 1.) * (x2max - x2min + 1.) |
|
|
| |
| xmin = max(x1min,x2min) |
| ymin = max(y1min,y2min) |
| xmax = min(x1max,x2max) |
| ymax = min(y1max,y2max) |
|
|
| inter_h = max(ymax - ymin + 1, 0) |
| inter_w = max(xmax - xmin + 1, 0) |
|
|
| intersection = inter_h * inter_w |
| union = s1 + s2 - intersection |
|
|
| |
    # A small epsilon guards against division by zero for degenerate boxes
    iou = intersection / (union + 1e-10)
    return iou
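
# Worked example (illustrative): two boxes in [ymax, xmax, ymin, xmin] order.
#   calculate_iou([4, 4, 0, 0], [4, 4, 2, 2])
# intersection = 3 * 3 = 9, union = 25 + 9 - 9 = 25, so IoU = 9 / 25 = 0.36
# (the "+1" terms follow the pixel-inclusive convention used above).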
|
|
def calculate_giou_loss(ground_truth, predicted):
    """
    Compute the GIoU loss between batches of boxes given as (N, 4) tensors
    in [ymax, xmax, ymin, xmin] order.
    """
    # Intersection
    inter_ymax = torch.min(ground_truth[:, 0], predicted[:, 0])
    inter_xmax = torch.min(ground_truth[:, 1], predicted[:, 1])
    inter_ymin = torch.max(ground_truth[:, 2], predicted[:, 2])
    inter_xmin = torch.max(ground_truth[:, 3], predicted[:, 3])
    intersection = torch.clamp(inter_xmax - inter_xmin, min=0) * torch.clamp(inter_ymax - inter_ymin, min=0)

    # Union
    gt_area = (ground_truth[:, 1] - ground_truth[:, 3]) * (ground_truth[:, 0] - ground_truth[:, 2])
    pred_area = (predicted[:, 1] - predicted[:, 3]) * (predicted[:, 0] - predicted[:, 2])
    union = gt_area + pred_area - intersection
    iou = intersection / (union + 1e-10)

    # Smallest axis-aligned box enclosing both boxes
    enclose_ymax = torch.max(ground_truth[:, 0], predicted[:, 0])
    enclose_xmax = torch.max(ground_truth[:, 1], predicted[:, 1])
    enclose_ymin = torch.min(ground_truth[:, 2], predicted[:, 2])
    enclose_xmin = torch.min(ground_truth[:, 3], predicted[:, 3])
    enclose_area = torch.clamp(enclose_xmax - enclose_xmin, min=0) * torch.clamp(enclose_ymax - enclose_ymin, min=0)

    giou = iou - (enclose_area - union) / (enclose_area + 1e-10)
    giou_loss = 1.0 - giou
    return giou_loss.mean()
|
|
|
|
class mulit_modality_Attention(nn.Module):
    def __init__(self, in_channels=16, out_channels=8, emb_dim=4096, output_dim=4096, att_dropout=0.0, aropout=0.0):
        super(mulit_modality_Attention, self).__init__()
        self.emb_dim = emb_dim
        self.scale = emb_dim ** -0.5

        self.Wq = nn.Linear(emb_dim, emb_dim)
        self.Wk = nn.Linear(emb_dim, emb_dim)
        self.Wv = nn.Linear(emb_dim, emb_dim)

        self.softmax = nn.Softmax(dim=-1)
        self.output_linear = nn.Linear(emb_dim, emb_dim)
        self.output_proj = nn.AvgPool2d(kernel_size=(2, 1), stride=(2, 1))
|
|
| |
    def forward(self, x_list, context=None, pad_mask=None):
        '''
        :param x_list: stacked modality features, e.g. [1-4, 32, 4096]
        :param context: [batch_size, seq_len, emb_dim] (unused)
        :param pad_mask: [batch_size, seq_len, seq_len] (unused)
        :return: fused multi-modality features
        '''
        b, h, w = x_list.shape

        query_list = self.Wq(x_list)
        key_list = self.Wk(x_list)
        value_list = self.Wv(x_list)

        q_c, q_w, q_h = query_list.shape

        # Scaled dot-product attention over the token dimension
        attention_score = torch.matmul(query_list, key_list.transpose(-1, -2))
        attention_score = self.softmax(attention_score * self.scale)

        attention_score = attention_score.reshape(q_c, -1, q_w, q_w)
        value_list = value_list.reshape(-1, q_c, q_w, q_h)

        output = torch.matmul(attention_score, value_list)
        output = self.output_linear(output)
        o_b, o_c, o_w, o_d = output.shape
        output = output.reshape(1, o_b * o_c, o_w, o_d)

        # When 16 channel groups are present, average-pool pairs of groups down to 8
        if o_b * o_c == 16:
            output = self.output_proj(output)
            output = output.reshape(1, -1, o_w, o_d)

        return output
|
def attention(query, key, value, mask=None, dropout=None):
    """Standard scaled dot-product attention; returns (output, attention weights)."""
    d_k = query.size(-1)
| scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(d_k) |
| if mask is not None: |
| scores = scores.masked_fill(mask == 0, -1e9) |
| p_attn = F.softmax(scores, dim=-1) |
| if dropout is not None: |
| p_attn = dropout(p_attn) |
| return torch.matmul(p_attn, value), p_attn |
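
# Example (a minimal sketch with assumed shapes): queries of shape [B, Lq, D] attend over
# keys/values of shape [B, Lk, D] and return an output of shape [B, Lq, D].
#   q = torch.randn(1, 32, 768); kv = torch.randn(1, 784, 768)
#   out, attn_weights = attention(q, kv, kv)   # out: [1, 32, 768], attn_weights: [1, 32, 784]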
|
|
|
|
| class resolution_attention(nn.Module): |
|
|
| def __init__(self, in_channels=16, out_channels=8, emb_dim=768, output_dim=768, dropout=0.1, aropout=0.0): |
| super(resolution_attention, self).__init__() |
| self.emb_dim = emb_dim |
| self.Wq = nn.Linear(emb_dim, emb_dim) |
| self.Wk = nn.Linear(emb_dim, emb_dim) |
| self.Wv = nn.Linear(emb_dim, emb_dim) |
| self.attn = None |
| self.output_linear = nn.Linear(emb_dim,emb_dim) |
| self.dropout = nn.Dropout(p=dropout) |
| self.dropout_2 = nn.Dropout(p=dropout) |
| self.norm = nn.LayerNorm(emb_dim) |
| |
    def forward(self, LR_image, HR_image, context=None, mask=None):
        '''
        Cross-attention in which low-resolution (LR) tokens query high-resolution (HR) tokens.
        :param LR_image: LR visual tokens, e.g. [1-4, 32, 768]
        :param HR_image: HR visual tokens aligned with the same slices
        :param context: [batch_size, seq_len, emb_dim] (unused)
        :param mask: optional attention mask
        :return: fused tokens with the LR token count, shape [1, -1, emb_dim]
        '''
        version = 2
        if version == 0:
            # Flatten both inputs into single sequences
            LR_image = LR_image.reshape(-1, LR_image.size(-1)).reshape(1, -1, LR_image.size(-1))
            HR_image = HR_image.reshape(-1, HR_image.size(-1)).reshape(1, -1, HR_image.size(-1))
        elif version == 1:
            # Each LR token attends over the 4 HR tokens of its group
            LR_image = LR_image.view(1, -1, 32, 1, LR_image.size(-1))
            HR_image = HR_image.view(1, -1, 32, 4, HR_image.size(-1))
        elif version == 2:
            # Each LR token attends over 4 HR views x 196 patch tokens of its group
            LR_image = LR_image.view(1, -1, 32, 1, LR_image.size(-1))
            HR_image = HR_image.view(1, -1, 32, 4 * 196, HR_image.size(-1))
        query_list = self.Wq(LR_image)
        key_list = self.Wk(HR_image)
        value_list = self.Wv(HR_image)
        x, self.attn = attention(query_list, key_list, value_list, mask=mask, dropout=self.dropout)
        if version in [1, 2]:
            x = x.view(1, -1, LR_image.size(-1))
            query_list = query_list.view(1, -1, LR_image.size(-1))

        x = self.output_linear(x)
        # Residual connection from the queries followed by LayerNorm
        x = self.norm(query_list + self.dropout_2(x))
        return x
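
# Example usage (a minimal sketch; token counts are illustrative assumptions that match the
# default `version = 2` reshape): groups of 32 LR tokens query 4*196 HR patch tokens each,
# and the output keeps the LR token count.
#   cross_attn = resolution_attention()
#   lr = torch.randn(4, 32, 768)            # viewed internally as (1, 4, 32, 1, 768)
#   hr = torch.randn(4, 32 * 4 * 196, 768)  # viewed internally as (1, 4, 32, 784, 768)
#   fused = cross_attn(lr, hr)              # (1, 128, 768)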
|
|
|
|
def bbox_giou_loss(pred_bboxes, target_bboxes):
    """
    Compute an IoU-based bounding-box loss between predicted and target boxes.

    Args:
    - pred_bboxes: Predicted boxes, tensor of shape (N, 4), each given as [ymax, xmax, ymin, xmin].
    - target_bboxes: Target boxes, tensor of shape (N, 4), each given as [ymax, xmax, ymin, xmin].

    Returns:
    - The mean of (1 - IoU) over the batch. Note that the enclosing-box penalty of the full
      GIoU formulation is not applied here, so this reduces to a plain IoU loss.
    """
    x_pred_max = torch.max(pred_bboxes[:, 1], pred_bboxes[:, 3])
    x_pred_min = torch.min(pred_bboxes[:, 1], pred_bboxes[:, 3])
    y_pred_min = torch.min(pred_bboxes[:, 0], pred_bboxes[:, 2])
    y_pred_max = torch.max(pred_bboxes[:, 0], pred_bboxes[:, 2])

    x_gt_max = torch.max(target_bboxes[:, 1], target_bboxes[:, 3])
    x_gt_min = torch.min(target_bboxes[:, 1], target_bboxes[:, 3])
    y_gt_min = torch.min(target_bboxes[:, 0], target_bboxes[:, 2])
    y_gt_max = torch.max(target_bboxes[:, 0], target_bboxes[:, 2])

    # Intersection
    ymin = torch.max(y_pred_min, y_gt_min)
    xmin = torch.max(x_pred_min, x_gt_min)
    ymax = torch.min(y_pred_max, y_gt_max)
    xmax = torch.min(x_pred_max, x_gt_max)
    intersection = torch.clamp(xmax - xmin, min=0) * torch.clamp(ymax - ymin, min=0)

    # Union
    pred_area = (x_pred_max - x_pred_min) * (y_pred_max - y_pred_min)
    target_area = (x_gt_max - x_gt_min) * (y_gt_max - y_gt_min)
    union = pred_area + target_area - intersection

    iou = intersection / (union + 1e-10)
    giou_loss = 1 - iou
    return giou_loss.mean()
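
# Example (illustrative): identical boxes give a loss of ~0, disjoint boxes a loss of ~1.
#   boxes = torch.tensor([[4., 4., 0., 0.]])
#   bbox_giou_loss(boxes, boxes)                               # ~0.0
#   bbox_giou_loss(boxes, torch.tensor([[10., 10., 6., 6.]]))  # ~1.0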
|
|
|
|
|
|
|
|
|
|
| class MiniGPTBase(BaseModel): |
| """ |
| Base class for MiniGPT-4 and MiniGPT-v2 |
| """ |
|
|
| def __init__( |
| self, |
| vit_model="eva_clip_g", |
| img_size=224, |
| drop_path_rate=0, |
| use_grad_checkpoint=False, |
| vit_precision="fp16", |
| freeze_vit=False, |
| llama_model="", |
| max_txt_len=32, |
| max_context_len=3800, |
| prompt_template="", |
| end_sym='\n', |
| low_resource=False, |
| device_8bit=0, |
| lora_r=8, |
| lora_target_modules=["q_proj", "v_proj"], |
| lora_alpha=32, |
| lora_dropout=0.05, |
| modality_number=5, |
| model_2d_or_3d="3d", |
| self_training=False |
| ): |
| super().__init__() |
| self.self_training=self_training |
| freeze_vit=not self_training |
| self.model_2d_or_3d=model_2d_or_3d |
| print('self.model_2d_or_3d',self.model_2d_or_3d) |
| print('self_training is ',self_training) |
| if self.self_training==False: |
| self.llama_model, self.llama_tokenizer = self.init_llm( |
| llama_model_path=llama_model, |
| low_resource=low_resource, |
| low_res_device=device_8bit, |
| lora_r=lora_r, |
| lora_target_modules=lora_target_modules, |
| lora_alpha=lora_alpha, |
| lora_dropout=lora_dropout, |
| ) |
| self.visual_encoder_list=[] |
| self.ln_vision_list=[] |
| self.modality_number=modality_number |
|
|
| |
| self.llama_model = self.llama_model.cuda() |
|
|
| |
| self.llama_tokenizer.add_special_tokens({'additional_special_tokens':["<box>", "<Img>", "</Img>", "<t>"]}) |
| |
| |
| |
| if len(self.llama_tokenizer) > self.llama_model.get_input_embeddings().weight.shape[0] and lora_r>0: |
| print("WARNING: Resizing the embedding matrix to match the tokenizer vocab size.") |
| self.llama_model.base_model.model.model.resize_token_embeddings(new_num_tokens = len(self.llama_tokenizer)) |
| else: |
| self.llama_model.base_model.resize_token_embeddings(new_num_tokens = len(self.llama_tokenizer)) |
| |
| |
| self.start_visual_token_idx = None |
| self.end_visual_token_idx = None |
|
|
| |
| |
| |
| |
| self.bounding_box_label = self.llama_tokenizer.convert_tokens_to_ids("<box>") |
| |
| |
| |
| |
| |
| |
| |
| self.cross_attention=resolution_attention() |
| self.L1_loss_fn=nn.SmoothL1Loss() |
|
| for i in range(self.modality_number): |
| |
| |
| |
| |
| if self.model_2d_or_3d=='2d': |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
|
|
| |
| model, _, _ = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224') |
                '''
                2024.09.02 Notes on the visual encoder setup:
                # To extract last-hidden-layer features, load the vision config
                vision_cfg = open_clip.CLIPVisionCfg
                # Set output_tokens=True so that forward() returns two values: the original
                # [CLS] token feature and the features of all tokens
                vision_cfg.output_tokens = True
                # Pass this config into the model
                model, _, _ = open_clip.create_model_and_transforms('hf-hub:microsoft/BiomedCLIP-PubMedBERT_256-vit_base_patch16_224', vision_cfg=vision_cfg)
                [This does not seem to work] BiomedCLIP is wrapped very deeply and exposes no
                parameter for choosing the output layer; many people have hit the same problem:
                https://github.com/mlfoundations/open_clip/issues/844

                Suggestion: try plain CLIP directly, and later swap in a self-pretrained CLIP.
                # Extract 1024-dim features
                from transformers import CLIPProcessor, CLIPVisionModel
                model_path = "models/clip-vit-large-patch14"
                processor = CLIPProcessor.from_pretrained(model_path)
                vision_model = CLIPVisionModel.from_pretrained(model_path).to(device)
                images = Image.open(path)
                inputs = processor(images=images, return_tensors="pt", padding=True).to(device)
                with torch.no_grad():
                    vision_outputs = vision_model(**inputs)
                # Last-hidden-layer features, (257, 1024), usable as local features
                last_hidden_state = vision_outputs.last_hidden_state
                # [CLS] token feature, (1024,), usable as a global feature
                pooler_output = vision_outputs.pooler_output
                '''
                '''
                Alternative: register a forward hook to extract BiomedCLIP's last-hidden-layer features.
                '''
| self.visual_encoder = model.visual.trunk |
| |
| self.visual_encoder.requires_grad_(False) |
|
|
| is_unfrozen = True |
| if is_unfrozen: |
| |
| self.visual_encoder.norm.weight.requires_grad_(True) |
| self.visual_encoder.norm.bias.requires_grad_(True) |
| |
| |
| |
| |
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| self.llama_proj_mlps = nn.Sequential( |
| nn.Linear(768, 4096), |
| |
| nn.GELU(), |
| |
| nn.Linear(4096, 4096), |
| ) |
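                # llama_proj_mlps lifts the 768-dim visual tokens into the 4096-dim LLM
                # embedding space (Linear -> GELU -> Linear) before they are interleaved
                # with the text embeddings.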
|
|
| |
| elif self.model_2d_or_3d=='3d': |
| print('-----------------------------3d vit encoder-------------') |
| model = AutoModel.from_pretrained( |
| "GoodBaiBai88/M3D-CLIP", |
| |
| trust_remote_code=True) |
| self.visual_encoder=model.vision_encoder |
| |
|
|
| |
| load_vit=False |
| if load_vit==True: |
| state_dict_b = self.visual_encoder.state_dict() |
| keys_list = list(state_dict_b.keys()) |
| file_path="/home/ynwang/MiniGPT-4/minigpt4/ckpt/vit/checkpoint_59.pth" |
| print("load vit encoder") |
| ckpt = torch.load(file_path, map_location="cpu") |
| |
| state_dict_a = ckpt['model'] |
| for para_idx,(name_a, param_a) in enumerate(state_dict_a.items()): |
| if para_idx>=144: |
| break |
| else: |
| cur_key=keys_list[para_idx] |
| if state_dict_b[cur_key].shape!=param_a.shape: |
| print("shape unmatch:{}".format(cur_key)) |
| else: |
| print("load parameter:{}".format(cur_key)) |
| state_dict_b[cur_key]=param_a |
| msg = self.visual_encoder.load_state_dict(state_dict_b, strict=False) |
| |
| |
| |
| self.visual_encoder_list.append(self.visual_encoder.to(self.device)) |
|
|
| |
| self.max_txt_len = max_txt_len |
| self.max_context_len = max_context_len |
| self.end_sym = end_sym |
| |
| self.prompt_template = prompt_template |
| self.prompt_list = [] |
| self.grad_list=[] |
| |
| |
| |
| def freeze_model(self): |
| visual_num=0 |
| llama_num=0 |
| for n,p in self.named_parameters(): |
| |
| if 'visual_encoder' in n or 'llama_proj_mlps' in n or 'ln_vision' in n: |
| |
| p.requires_grad=True |
| visual_num+=p.numel() |
| else: |
| llama_num+=p.numel() |
| p.requires_grad=False |
| self.grad_list.append(p) |
| |
| def vit_to_cpu(self): |
| |
| |
| self.visual_encoder.to("cpu") |
| self.visual_encoder.float() |
|
|
| |
| |
| |
    def get_context_emb(self, prompt, img_list):
        """Interleave prompt text embeddings with visual embeddings at the '<ImageHere>' placeholders."""
        device = img_list[0].device
        prompt_segs = prompt.split('<ImageHere>')
| |
| |
| |
| seg_tokens = [ |
| self.llama_tokenizer( |
| seg, return_tensors="pt", add_special_tokens=i==0).to(device).input_ids |
| for i, seg in enumerate(prompt_segs) |
| ] |
| seg_embs = [self.embed_tokens(seg_t) for seg_t in seg_tokens] |
|
|
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| count_current_len=1 |
| old_version=1 |
| if old_version==0: |
| mixed_embs = [emb for pair in zip(seg_embs[:-1], img_list) for emb in pair] + [seg_embs[-1]] |
| mixed_embs = torch.cat(mixed_embs, dim=1) |
|
|
| |
| |
| bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids) |
| |
| |
| |
| |
| |
| |
| eos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids) |
|
|
| mixed_embs = torch.cat([bos_embed, mixed_embs,eos_embed], dim=1) |
| else: |
| image_list=[] |
| cur_att=1 |
| if cur_att==0: |
| for index in range(1,5): |
| image_list.append(img_list[0][:,32*(index-1):32*index]) |
| mixed_embs = [emb for pair in zip(seg_embs[:-1], image_list) for emb in pair] + [seg_embs[-1]] |
|
|
| |
| |
| |
| |
| |
| |
| mixed_embs=mixed_embs+[seg_embs[-1]] |
|
|
| |
| mixed_embs = torch.cat(mixed_embs, dim=1) |
| bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids) |
| mixed_embs = torch.cat([bos_embed, mixed_embs], dim=1) |
| |
|
|
| |
|
|
| elif cur_att==1: |
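                # Split the projected visual sequence into 4 groups of 32 tokens and record each
                # group's start/end position; these indices are later handed to the LLM through
                # self.llama_model.set_index(...).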
| image_list=[] |
|
|
| |
|
|
| for index in range(1,5): |
| image_list.append(img_list[0][:,32*(index-1):32*index]) |
| |
|
|
| count_current_len = 1 |
| self.start_visual_token_idx = [] |
| self.end_visual_token_idx = [] |
| mixed_embs=[] |
| for pair in zip(seg_embs[:-1], image_list): |
| count_current_len+=1 |
| self.start_visual_token_idx.append(count_current_len) |
| for emb in pair: |
| |
| |
| mixed_embs.append(emb) |
| count_current_len += 32 |
| self.end_visual_token_idx.append(count_current_len) |
| self.llama_model.set_index(self.start_visual_token_idx,self.end_visual_token_idx) |
| mixed_embs=mixed_embs+[seg_embs[-1]] |
|
|
| mixed_embs = torch.cat(mixed_embs, dim=1) |
| bos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.bos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids) |
| eos_embed = self.embed_tokens(self.llama_tokenizer(self.llama_tokenizer.eos_token, return_tensors="pt", add_special_tokens=False).to(mixed_embs.device).input_ids) |
| |
|
|
| mixed_embs = torch.cat([bos_embed, mixed_embs,eos_embed], dim=1) |
| return mixed_embs |
|
|
| def prompt_wrap(self, img_embeds, atts_img, prompts, lengths=None): |
| if prompts is None or len(prompts) == 0: |
| |
| return img_embeds, atts_img |
| elif img_embeds is None: |
| |
| self.llama_tokenizer.padding_side = "right" |
| prompt_tokens = self.llama_tokenizer( |
| prompts, |
| return_tensors="pt", |
| padding="longest", |
| add_special_tokens=False |
| ).to(self.device) |
| prompt_embeds = self.embed_tokens(prompt_tokens.input_ids) |
| atts_prompt = prompt_tokens.attention_mask |
| return prompt_embeds, atts_prompt |
| else: |
| |
| emb_lists = [] |
| if isinstance(prompts, str): |
| prompts = [prompts] * len(img_embeds) |
| self.start_visual_token_idx = None |
| self.end_visual_token_idx = None |
| cond_ids_list=[] |
| for idx, (each_img_embed, each_prompt) in enumerate(zip(img_embeds, prompts)): |
| pn = each_img_embed.shape[-2] |
| |
|
|
| if lengths is not None: |
| each_img_embed = each_img_embed.reshape(-1, each_img_embed.shape[-1]) |
| each_img_embed = each_img_embed[:lengths[idx] * pn] |
| p_segs = each_prompt.split('<ImageHere>') |
| interleave_emb = [] |
| interleave_ids=[] |
                '''2024.09.08 modified by Yanzhaoshi: split the visual features into groups of 32:
                   <bos> <Img> <ImageHere> <t> <ImageHere> <t> <ImageHere> <t> <ImageHere> </Img> {instruction} '''
| count_current_len = 1 |
| self.start_visual_token_idx = [] |
| self.end_visual_token_idx = [] |
| if '<t>' in each_prompt: |
| |
| |
| vis_chunk_size = 32 |
| for idx, seg in enumerate(p_segs[:-1]): |
| |
| p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) |
| |
| interleave_ids += p_tokens.input_ids.squeeze(0).tolist() |
| |
| p_embed = self.embed_tokens(p_tokens.input_ids) |
| |
| count_current_len += p_embed.size(1) |
| |
| self.start_visual_token_idx.append(count_current_len) |
| |
| interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * vis_chunk_size:(idx + 1) * vis_chunk_size]], dim=1)) |
| |
| interleave_ids += self.llama_tokenizer('<t>', add_special_tokens=False).input_ids * (int(pn/4)) |
|
|
| |
| count_current_len += vis_chunk_size |
| |
| self.end_visual_token_idx.append(count_current_len) |
|
|
|
|
| else: |
| for idx, seg in enumerate(p_segs[:-1]): |
| |
| p_tokens = self.llama_tokenizer(seg, return_tensors="pt", add_special_tokens=False).to(img_embeds.device) |
| p_embed = self.embed_tokens(p_tokens.input_ids) |
| |
| count_current_len += p_embed.size(1) |
| |
| self.start_visual_token_idx.append(count_current_len) |
| |
| interleave_emb.append(torch.cat([p_embed, each_img_embed[None][:, idx * pn:(idx + 1) * pn]], dim=1)) |
| |
| count_current_len += pn |
| |
| self.end_visual_token_idx.append(count_current_len) |
|
|
| wrapped_emb = torch.cat(interleave_emb, dim=1) |
|
|
| |
| |
| |
| |
|
|
| p_tokens = self.llama_tokenizer( |
| p_segs[-1], return_tensors="pt", add_special_tokens=False).to(img_embeds.device) |
| |
                interleave_ids += p_tokens.input_ids.squeeze(0).tolist()
                cond_ids_list.append(interleave_ids)
                # NOTE: converting to a tensor inside the per-sample loop assumes a batch size of 1.
                cond_ids_list = torch.tensor(cond_ids_list).to(img_embeds.device)
|
|
| |
| |
| p_embed = self.embed_tokens(p_tokens.input_ids) |
| wrapped_emb = torch.cat([wrapped_emb, p_embed], dim=1) |
| |
| |
| |
| |
| |
| |
| |
| emb_lists.append(wrapped_emb) |
|
|
| emb_lens = [emb.shape[1] for emb in emb_lists] |
| pad_emb = self.embed_tokens(torch.tensor(self.llama_tokenizer.pad_token_id, device=img_embeds.device)) |
|
|
| max_length = max(emb_lens) if max(emb_lens) < self.max_context_len else self.max_context_len |
| wrapped_embs = pad_emb.expand(len(emb_lens), max_length, -1).clone() |
| wrapped_atts = torch.zeros([len(emb_lens), max_length], dtype=torch.int, device=img_embeds.device) |
| |
| for i, emb in enumerate(emb_lists): |
| length = emb_lens[i] if emb_lens[i] < self.max_context_len else self.max_context_len |
| wrapped_embs[i, :length] = emb[:, :length] |
| wrapped_atts[i, :length] = 1 |
| return wrapped_embs, wrapped_atts,cond_ids_list |
|
|
| def concat_emb_input_output(self, input_embs, input_atts, output_embs, output_atts): |
| """ |
| Concatenate the batched input embedding and batched output embedding together. |
| Both the input and the output embedding should be right padded. |
| """ |
| input_lens = [] |
| cat_embs = [] |
| cat_atts = [] |
| for i in range(input_embs.size(0)): |
| input_len = input_atts[i].sum() |
| input_lens.append(input_len) |
| cat_embs.append( |
| torch.cat([ |
| input_embs[i][:input_len], |
| output_embs[i], |
| input_embs[i][input_len:] |
| ]) |
| ) |
| cat_atts.append( |
| torch.cat([ |
| input_atts[i][:input_len], |
| output_atts[i], |
| input_atts[i][input_len:] |
| ]) |
| ) |
| cat_embs = torch.stack(cat_embs) |
| cat_atts = torch.stack(cat_atts) |
| return cat_embs, cat_atts, input_lens |
|
|
| def tokenize_conversation(self, conv_q, conv_a): |
| """concatenate conversation and make sure the model is only trained to regress the answer""" |
|
|
| to_regress_token_ids_list = [] |
| targets_list = [] |
|
|
| batch_size = len(conv_q) |
| for batch_idx in range(batch_size): |
| questions, answers = conv_q[batch_idx], conv_a[batch_idx] |
| questions = [self.llama_tokenizer(self.llama_tokenizer.bos_token + q, |
| return_tensors="pt", |
| add_special_tokens=False).to(self.device) for q in questions[1:]] |
| answers = [self.llama_tokenizer(a + self.end_sym, |
| return_tensors="pt", |
| add_special_tokens=False).to(self.device) for a in answers] |
| cur_id = [] |
| cur_target = [] |
| for i in range(len(questions)): |
| cur_id.append(answers[i].input_ids) |
| cur_target.append(answers[i].input_ids) |
| cur_id.append(questions[i].input_ids) |
| cur_target.append(torch.ones_like(questions[i].input_ids) * -100) |
|
|
| cur_id.append(answers[-1].input_ids) |
| cur_target.append(answers[-1].input_ids) |
|
|
| cur_id = torch.cat(cur_id, dim=1) |
| cur_target = torch.cat(cur_target, dim=1) |
| to_regress_token_ids_list.append(cur_id) |
| targets_list.append(cur_target) |
|
|
| max_len = min(max([target.shape[1] for target in targets_list]), self.max_txt_len) |
| to_regress_token_ids = torch.ones([batch_size, max_len], |
| dtype=cur_id.dtype, device=self.device) * self.llama_tokenizer.pad_token_id |
| targets = torch.ones([batch_size, max_len], |
| dtype=cur_id.dtype, device=self.device) * -100 |
| for batch_idx in range(batch_size): |
| cur_len = to_regress_token_ids_list[batch_idx].shape[1] |
| to_regress_token_ids[batch_idx, :cur_len] = to_regress_token_ids_list[batch_idx][0, :max_len] |
| targets[batch_idx, :cur_len] = targets_list[batch_idx][0, :max_len] |
|
|
| to_regress_token_attn = (to_regress_token_ids != self.llama_tokenizer.pad_token_id).to(torch.int) |
|
|
| return to_regress_token_ids, to_regress_token_attn, targets |
|
|
| def preparing_embedding(self, samples): |
| |
| if 'several_modalities_image' in samples: |
| |
| |
| images=samples['several_modalities_image']['images'] |
| |
| |
| modalities=samples['several_modalities_image']['modalities'] |
| |
| device=images[0].device |
| temp=[] |
| for index,image in enumerate(images): |
| |
| |
| |
| img_embeds, img_atts,self_loss=self.encode_img(image) |
|
|
| |
| temp.append(img_embeds) |
| img_embeds=torch.cat(temp, dim=1) |
| |
| elif 'image' in samples: |
| if type(samples["image"])==list: |
| images = torch.stack(samples["image"]).to(self.device) |
| |
| else: |
| images=samples["image"].to(self.device) |
| |
| |
| img_embeds, img_atts, self_loss = self.encode_img(images, "LR_Encoding") |
| |
| if "HR_resolution" in samples.keys() and samples['HR_resolution'] ==True: |
| HR_img_embeds, HR_img_atts, HR_self_loss = self.encode_img(samples["HR_image_list"].to(self.device), "HR_Encoding") |
| |
| |
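                # Fuse the low-resolution tokens with the high-resolution tokens via
                # cross-attention, then project the result into the LLM embedding space.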
| img_embeds=self.cross_attention(img_embeds.reshape(1,-1,768), HR_img_embeds.reshape(1,-1,768)) |
| |
| |
| |
| img_embeds=self.llama_proj_mlps(img_embeds) |
|
|
|
|
| device=images.device |
| else: |
| img_embeds = img_atts = None |
| |
| if 'conv_q' in samples: |
| |
| conv_q, conv_a = samples['conv_q'], samples['conv_a'] |
|
|
| connect_sym = samples['connect_sym'][0] |
| conv_q = [q.split(connect_sym)for q in conv_q] |
| conv_a = [a.split(connect_sym) for a in conv_a] |
|
|
| conv_q = [[self.prompt_template.format(item) for item in items] for items in conv_q] |
|
|
| cond_embeds, cond_atts,_ = self.prompt_wrap(img_embeds, img_atts, [q[0] for q in conv_q]) |
| regress_token_ids, regress_atts, part_targets = self.tokenize_conversation(conv_q, conv_a) |
|
|
| else: |
| if "instruction_input" in samples: |
| instruction = samples["instruction_input"] |
| elif self.prompt_list: |
| instruction = random.choice(self.prompt_list) |
| else: |
| instruction = None |
|
|
| if hasattr(self, 'chat_template') and self.chat_template: |
| instruction = [self.prompt_template.format(instruct) for instruct in instruction] |
|
|
| if 'length' in samples: |
| |
| bsz, pn, hs = img_embeds.shape |
| |
| img_embeds = img_embeds.reshape(len(samples['image']), -1, pn, hs) |
|
|
| cond_embeds, cond_atts,_ = self.prompt_wrap(img_embeds, img_atts, instruction, samples['length']) |
| else: |
| |
| |
| |
| |
|
|
| |
| |
| |
| if img_embeds.size(0)>8 and self.has_qformer==True: |
| img_embeds=img_embeds.reshape(img_embeds.size(0),img_embeds.size(1),64,64) |
| |
| |
| img_embeds=img_embeds.permute(1,0,2,3) |
| img_embeds=self.conv3d_layer(img_embeds) |
| img_embeds=img_embeds.reshape(img_embeds.size(0),img_embeds.size(1),4096) |
| img_embeds=self.conv3d_proj(img_embeds).reshape(1,-1,4096) |
| else: |
| img_embeds=img_embeds.reshape(1,-1,4096) |
| |
| |
|
|
| |
| |
| |
| cond_embeds, cond_atts,_ = self.prompt_wrap(img_embeds, img_atts, instruction) |
|
|
| |
| self.llama_tokenizer.padding_side = "right" |
| text = [t + self.end_sym for t in samples["answer"]] |
| |
| regress_tokens = self.llama_tokenizer( |
| text, |
| return_tensors="pt", |
| padding="longest", |
| truncation=True, |
| max_length=self.max_txt_len, |
| add_special_tokens=False |
| ).to(self.device) |
| |
| regress_token_ids = regress_tokens.input_ids |
| |
| regress_atts = regress_tokens.attention_mask |
| part_targets = regress_token_ids.masked_fill( |
| regress_token_ids == self.llama_tokenizer.pad_token_id, -100 |
| ) |
| |
| |
| if 'class_target' in samples: |
| class_target=samples['class_target'] |
| class_target=class_target.to(device) |
| else: |
| class_target=None |
| regress_embeds = self.embed_tokens(regress_token_ids) |
|
|
|
|
| return cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets, img_embeds, class_target, self_loss |
|
|
| def forward(self, samples, reduction='mean'): |
| |
| |
| |
| if self.self_training == False: |
| cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets,input_embs,class_target,self_loss = \ |
| self.preparing_embedding(samples) |
| else: |
| self.visual_encoder=self.visual_encoder_list[0] |
| image=samples['image'].to(self.device) |
| if len(image.shape) > 4: |
| image = image.reshape(-1, *image.shape[-3:]) |
| with self.maybe_autocast(): |
| loss,_,_,latent=self.visual_encoder(image) |
| |
| return {"loss": loss} |
        if class_target != None:
            img_num = input_embs.shape[1] / 32
            classification_loss_fn = torch.nn.CrossEntropyLoss()
            # NOTE: `class_predict` is expected from a classification head (see predict_class_,
            # which applies self.class_layer to the LLM hidden states); it is not produced in
            # this forward pass, so this branch requires that head to be wired in.
            class_loss = classification_loss_fn(class_predict, class_target)
        else:
            class_loss = None
| inputs_embeds, attention_mask, input_lens = \ |
| self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts) |
|
|
| |
| bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id |
| bos_embeds = self.embed_tokens(bos) |
| bos_atts = cond_atts[:, :1] |
|
|
| |
| |
| inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) |
| attention_mask = torch.cat([bos_atts, attention_mask], dim=1) |
|
|
| |
| targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], |
| dtype=torch.long).to(self.device).fill_(-100) |
|
|
| for i, target in enumerate(part_targets): |
| targets[i, input_lens[i]+1:input_lens[i]+len(target)+1] = target |
| |
| |
| |
| |
| |
| |
| |
| if self_loss==None: |
| with self.maybe_autocast(): |
| outputs = self.llama_model( |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| start_visual_idx = self.start_visual_token_idx, |
| end_visual_idx = self.end_visual_token_idx, |
| return_dict=True, |
| labels=targets, |
| reduction=reduction |
| ) |
| hidden_states=outputs.hidden_states[-1] |
| |
| |
| |
| |
| |
| if "bounding_box" in samples: |
| gt_bounding_box=samples["bounding_box"] |
| |
| |
| |
| output_token_index = (targets == self.bounding_box_label).nonzero() |
| |
| if len(output_token_index)>0: |
| location_state=hidden_states[output_token_index[0][-1]] |
| bounding_box=self.seg_layer(location_state.view(1, -1)) |
| |
| |
| |
| |
| |
| L1_loss=self.L1_loss_fn(bounding_box,gt_bounding_box) |
| |
| giou_loss=bbox_giou_loss(bounding_box,gt_bounding_box) |
| |
|
|
| if giou_loss>0: |
| |
| loss=outputs.loss+0.2*L1_loss+1.2*giou_loss |
| else: |
| loss=0.8*outputs.loss+0.2*L1_loss |
| else: |
| loss=outputs.loss |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| else: |
| loss=outputs.loss |
| |
|
|
| |
| |
| |
| |
| |
| if torch.isnan(loss): |
| print(samples['image_name']) |
| |
| return {"loss": loss} |
|
|
| def embed_tokens(self, token_ids): |
| if hasattr(self.llama_model.base_model, 'model'): |
| embeds = self.llama_model.base_model.model.model.embed_tokens(token_ids) |
| else: |
| embeds = self.llama_model.base_model.embed_tokens(token_ids) |
| return embeds |
|
|
| @torch.no_grad() |
| def multi_generate( |
| self, |
| samples, |
| texts, |
| num_beams=1, |
| max_new_tokens=160, |
| min_length=1, |
| top_p=0.9, |
| repetition_penalty=1, |
| length_penalty=1, |
| temperature=1, |
| do_sample=False, |
| stop_words_ids=[2], |
| ): |
| ''' |
| function for generate test use |
| ''' |
|
|
| stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( |
| stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) |
| |
| |
|
|
| if 'several_modalities_image' in samples: |
| images=samples['several_modalities_image']['images'] |
| |
| modalities=samples['several_modalities_image']['modalities'] |
| |
| |
| temp=[] |
| for index,image in enumerate(images): |
| |
| |
| img_embeds, img_atts,self_loss = self.encode_img(image.to(self.device).half(),modalities[index]) |
| |
| temp.append(img_embeds) |
| img_embeds=torch.cat(temp, dim=1) |
|
|
| image_lists = [[image_emb[None]] for image_emb in img_embeds] |
| |
| |
| |
| batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] |
|
|
| batch_size = len(batch_embs) |
| max_len = max([emb.shape[1] for emb in batch_embs]) |
| emb_dim = batch_embs[0].shape[2] |
| dtype = batch_embs[0].dtype |
| device = batch_embs[0].device |
|
|
| embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) |
| attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) |
| for i, emb in enumerate(batch_embs): |
| emb_len = emb.shape[1] |
| embs[i, -emb_len:] = emb[0] |
| attn_mask[i, -emb_len:] = 1 |
| with self.maybe_autocast(): |
| outputs = self.llama_model.generate( |
| inputs_embeds=embs, |
| attention_mask=attn_mask, |
| max_new_tokens=max_new_tokens, |
| num_beams=num_beams, |
| length_penalty=length_penalty, |
| temperature=temperature, |
| do_sample=do_sample, |
| min_length=min_length, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| stopping_criteria=stopping_criteria, |
| ) |
|
|
| answers = [] |
| for output_token in outputs: |
| if output_token[0] == 0: |
| output_token = output_token[1:] |
| output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) |
| output_texts = output_texts.split('</s>')[0] |
| output_texts = output_texts.replace("<s>", "") |
| output_texts = output_texts.split(r'[/INST]')[-1].strip() |
| answers.append(output_texts) |
|
|
| return answers |
| |
| @torch.no_grad() |
| def generate( |
| self, |
| images, |
| texts, |
| num_beams=1, |
| max_new_tokens=300, |
| min_length=1, |
| top_p=0.9, |
| repetition_penalty=1, |
| length_penalty=1, |
| temperature=1, |
| do_sample=False, |
| stop_words_ids=[2], |
| ): |
| ''' |
| function for generate test use |
| ''' |
|
|
| stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( |
| stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) |
| |
| img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device)) |
|
|
| image_lists = [[image_emb[None]] for image_emb in img_embeds] |
| |
| |
| |
| batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] |
|
|
| batch_size = len(batch_embs) |
| max_len = max([emb.shape[1] for emb in batch_embs]) |
| emb_dim = batch_embs[0].shape[2] |
| dtype = batch_embs[0].dtype |
| device = batch_embs[0].device |
|
|
| embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) |
| attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) |
| for i, emb in enumerate(batch_embs): |
| emb_len = emb.shape[1] |
| embs[i, -emb_len:] = emb[0] |
| attn_mask[i, -emb_len:] = 1 |
|
|
| with self.maybe_autocast(): |
| outputs = self.llama_model.generate( |
| inputs_embeds=embs, |
| attention_mask=attn_mask, |
| max_new_tokens=max_new_tokens, |
| num_beams=num_beams, |
| length_penalty=length_penalty, |
| temperature=temperature, |
| do_sample=do_sample, |
| min_length=min_length, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| stopping_criteria=stopping_criteria, |
| ) |
|
|
| answers = [] |
| for output_token in outputs: |
| if output_token[0] == 0: |
| output_token = output_token[1:] |
| output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) |
| output_texts = output_texts.split('</s>')[0] |
| output_texts = output_texts.replace("<s>", "") |
| output_texts = output_texts.split(r'[/INST]')[-1].strip() |
| answers.append(output_texts) |
|
|
| return answers |
|
|
| |
| |
| |
| |
| |
| |
| |
| def predict_class(self,samples,device): |
| |
| |
| |
| images=samples['several_modalities_image']['images'] |
| |
| modalities=samples['several_modalities_image']['modalities'] |
| |
| |
| temp=[] |
| num=len(images) |
| for index,image in enumerate(images): |
| |
| |
| img_embeds, img_atts,self_loss=self.encode_img(image.to(device).half(),modalities[index]) |
| |
| temp.append(img_embeds) |
| img_embeds=torch.cat(temp, dim=1) |
|
|
| |
| |
| |
| |
        # NOTE: the classification-head call that produces `class_predict` is missing here;
        # see predict_class_ below, which runs the LLM and applies self.class_layer.
        return class_predict
|
|
|
|
| @torch.no_grad() |
| def predict_boundingbox( |
| self, |
| images, |
| texts, |
| HR_images=None, |
| num_beams=1, |
| max_new_tokens=300, |
| min_length=1, |
| top_p=0.9, |
| repetition_penalty=1, |
| length_penalty=1, |
| temperature=1, |
| do_sample=False, |
| stop_words_ids=[2], |
| output_hidden_states=True |
| ): |
| ''' |
| function for generate test use |
| ''' |
|
|
| stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( |
| stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) |
| |
| img_embeds, atts_img, self_loss = self.encode_img(images.to(self.device), "LR_Encoding") |
|
|
| if HR_images!=[]: |
| |
| HR_embeds,HR_atts_img,HR_loss = self.encode_img(HR_images.to(self.device), "HR_Encoding") |
| img_embeds = self.cross_attention(img_embeds.reshape(1,-1,768).float(),HR_embeds.reshape(1,-1,768).float()) |
| |
| img_embeds=self.llama_proj_mlps(img_embeds) |
| img_embeds=img_embeds.reshape(1,-1,*img_embeds.shape[-1:]) |
|
|
| if img_embeds.size(0)>8: |
| |
| img_embeds=img_embeds.reshape(1,-1,4096) |
| |
|
|
| |
| else: |
| img_embeds=img_embeds.reshape(1,-1,4096) |
| |
| |
| |
| image_lists = [[image_emb[None]] for image_emb in img_embeds] |
|
|
| batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] |
|
|
| batch_size = len(batch_embs) |
| max_len = max([emb.shape[1] for emb in batch_embs]) |
| emb_dim = batch_embs[0].shape[2] |
| dtype = batch_embs[0].dtype |
| device = batch_embs[0].device |
|
|
| embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) |
| attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) |
| for i, emb in enumerate(batch_embs): |
| emb_len = emb.shape[1] |
| embs[i, -emb_len:] = emb[0] |
| attn_mask[i, -emb_len:] = 1 |
|
|
| with self.maybe_autocast(): |
| outputs = self.llama_model.generate( |
| inputs_embeds=embs, |
| attention_mask=attn_mask, |
| max_new_tokens=max_new_tokens, |
| num_beams=num_beams, |
| length_penalty=length_penalty, |
| temperature=temperature, |
| do_sample=do_sample, |
| min_length=min_length, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| stopping_criteria=stopping_criteria, |
| return_dict_in_generate=True, |
| output_hidden_states=output_hidden_states, |
| pad_token_id=2, |
| eos_token_id=2 |
| ) |
| hidden_states=outputs['hidden_states'] |
| |
|
|
| answers = [] |
| top_2_list=self.llama_model.get_top() |
| if top_2_list[0]!=None and top_2_list[1]!=None: |
| top2_answer=self.llama_tokenizer.decode(top_2_list, skip_special_tokens=True) |
| |
| self.llama_model.delete_max_id() |
|
|
| for output_token in outputs['sequences']: |
| if output_token[0] == 0: |
| output_token = output_token[1:] |
| |
| output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) |
            '''
            Debug note: output_texts can be surprisingly long. Shouldn't generation finish the
            answer, predict the end symbol, and stop, so that everything afterwards is just end
            tokens that get skipped during decoding? Example of the runaway output:

            'This patient was diagnosed with Embryonal tumours### Short diagnosis: Embryonal
            tumours### Diagnosis: MedulloblastomaThis patient was diagnosed with Medulloblastoma
            ### Diagnosis: Embryonal tumours### Diagnosis: Medulloblastoma### ... (the same two
            diagnoses keep alternating) ... ### Diagnosis: Embryonal tum'

            The lines below look like end-of-sequence handling; was the end symbol actually
            appended to the answers during training?
            '''
| output_texts = output_texts.split('</s>')[0] |
| output_texts = output_texts.replace("<s>", "") |
| output_texts = output_texts.split(r'[/INST]')[-1].strip() |
| answers.append(output_texts) |
| |
| |
| |
| return answers |
| def get_input_embeddings(self): |
| return self.llama_model.get_input_embeddings() |
| |
| @torch.no_grad() |
| def generate_step( |
| self, |
| images, |
| texts, |
| HR_images=None, |
| num_beams=1, |
| echo=False, |
| max_new_tokens=500, |
| min_length=1, |
| top_p=0.9, |
| repetition_penalty=1, |
| length_penalty=1, |
| temperature=-1, |
| do_sample=False, |
| stop_words_ids=[2], |
| output_hidden_states=True |
| ): |
| ''' |
| function for generate test use |
| ''' |
| |
| |
| stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( |
| stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) |
| |
| |
| img_embeds, atts_img, self_loss = self.encode_img(images.to(self.device), "LR_Encoding") |
|
|
| if HR_images!=[]: |
| |
| HR_embeds,HR_atts_img,HR_loss = self.encode_img(HR_images.to(self.device), "HR_Encoding") |
| img_embeds = self.cross_attention(img_embeds.reshape(1,-1,768).float(),HR_embeds.reshape(1,-1,768).float()) |
| |
| img_embeds=self.llama_proj_mlps(img_embeds) |
| bsz=img_embeds.shape[0] |
| img_embeds=img_embeds.reshape(1,-1,*img_embeds.shape[-1:]) |
| |
| if img_embeds.size(0)>8: |
| |
| img_embeds=img_embeds.reshape(1,-1,4096) |
| |
|
|
| |
| else: |
| img_embeds=img_embeds.reshape(1,-1,4096) |
| |
| |
| |
| |
|
|
| |
| inputs_embeds,attention_mask,input_ids=self.prompt_wrap(img_embeds,atts_img,texts) |
|
|
| |
| bos = torch.ones_like(attention_mask[:, :1]) * self.llama_tokenizer.bos_token_id |
| bos_embeds = self.embed_tokens(bos) |
| bos_atts = attention_mask[:, :1] |
|
|
| |
| inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) |
| |
| attention_mask = torch.cat([bos_atts, attention_mask], dim=1) |
| input_ids = torch.cat([bos, input_ids], dim=1) |
|
|
|
|
| |
| |
| |
|
|
|
|
| |
| batch_size = len(inputs_embeds) |
| |
| |
| emb_dim = inputs_embeds.shape[2] |
| dtype = inputs_embeds.dtype |
| device = inputs_embeds.device |
| min_prompt_len = min(len(t) for t in inputs_embeds) |
| max_prompt_len = max(len(t) for t in inputs_embeds) |
| total_len = max_new_tokens + min_prompt_len |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| self.pad_token_id=self.llama_tokenizer.pad_token_id |
| pad_id = self.pad_token_id |
| tokens = torch.full((batch_size, total_len), pad_id, dtype=torch.int, device=device) |
| embs = torch.zeros([batch_size, total_len, emb_dim], dtype=dtype, device=device) |
| for k, t in enumerate(inputs_embeds): |
| embs[k, : len(t)] = t |
| for k, t in enumerate(input_ids): |
| tokens[k, : len(t)] = t |
| |
| prev_pos=0 |
        # Track end-of-sequence per sample; use batch_size (bsz is only defined on the HR path)
        eos_reached = torch.tensor([False] * batch_size, device=self.device)
| input_text_mask = tokens != pad_id |
| self.stop_token_id = int(self.llama_tokenizer.encode("###")[-1]) |
| stop_tokens = torch.tensor([self.stop_token_id], device=self.device) |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
| |
| |
| past_key_values=None |
| |
| |
| |
| with_probs = False |
| calculate_prob=False |
| |
| |
| for cur_pos in range(min_prompt_len, total_len): |
| with self.maybe_autocast(): |
| outputs = self.llama_model( |
| |
| |
| past_key_values=past_key_values, |
| inputs_embeds=embs[:, prev_pos:cur_pos], |
| use_cache=True, |
| output_hidden_states=output_hidden_states, |
| return_dict=True, |
| start_visual_idx = self.start_visual_token_idx, |
| end_visual_idx = self.end_visual_token_idx, |
| ) |
| |
| logits = outputs["logits"] |
| past_key_values = outputs["past_key_values"] |
|
|
| if temperature > 0: |
| probs = torch.softmax(logits[:, -1] / temperature, dim=-1) |
| next_token = sample_top_p(probs, top_p) |
| else: |
| next_token = torch.argmax(logits[:, -1], dim=-1) |
|
|
| next_token = next_token.reshape(-1) |
| |
| next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token) |
| import copy |
| if calculate_prob: |
| probs = torch.softmax(logits[:, -1], dim=-1) |
| index_list=[15,16,17,18,19,20,21,22,23,24,605,806,717,1032] |
| new_logits=copy.deepcopy(logits[0][0]) |
| tumor_list=[new_logits[k] for k in index_list] |
| tumor_list = torch.stack(tumor_list) |
| probabilities = torch.softmax(tumor_list, dim=0) |
| token_dict={probabilities[k]:index_list[k] for k in range(len(index_list))} |
| top_2=sorted(token_dict.keys(), reverse=True)[:2] |
| |
| top1=token_dict[top_2[0]] |
| top2=token_dict[top_2[1]] |
| print(f'<class_{self.llama_tokenizer.decode(top1)}>',top_2[0].item(),f'<class_{self.llama_tokenizer.decode(top2)}>',top_2[1].item()) |
| with_probs=False |
| calculate_prob=False |
| if next_token==449: |
| with_probs=True |
| if with_probs: |
| if next_token==62: |
| calculate_prob=True |
| |
| tokens[:, cur_pos] = next_token |
| next_token_embed = self.llama_model.get_input_embeddings()(next_token) |
| embs[:, cur_pos] = next_token_embed |
| |
| eos_reached |= (~input_text_mask[:, cur_pos]) & ( |
| torch.isin(next_token, stop_tokens.to(self.device)) |
| ) |
| prev_pos = cur_pos |
| if all(eos_reached): |
| break |
| out_tokens, out_logprobs = [], [] |
| for i, toks in enumerate(tokens.tolist()): |
| |
| start = 0 if echo else len(input_ids[i]) |
| toks = toks[start: len(input_ids[i]) + max_new_tokens] |
| probs = None |
| |
| |
| |
| for stop_token in [self.stop_token_id]: |
| try: |
| eos_idx = toks.index(stop_token) |
| toks = toks[:eos_idx] |
| |
| except ValueError: |
| pass |
| out_tokens.append(toks) |
| |
| answers = self.llama_tokenizer.decode(out_tokens[0]) |
| |
| return answers |
| |
| |
| |
| def generate_segmentation( |
| self, |
| images, |
| texts, |
| num_beams=1, |
| max_new_tokens=300, |
| min_length=1, |
| top_p=0.9, |
| repetition_penalty=1, |
| length_penalty=1, |
| temperature=1, |
| do_sample=False, |
| stop_words_ids=[2], |
| output_hidden_states=True |
| ): |
| ''' |
| function for generate test use |
| ''' |
|
|
| stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub( |
| stops=[torch.tensor([i]).to(self.device) for i in stop_words_ids])]) |
| |
| img_embeds, atts_img,self_loss = self.encode_img(images.to(self.device)) |
| |
| image_lists = [[image_emb[None]] for image_emb in img_embeds] |
| |
| |
| |
| batch_embs = [self.get_context_emb(text, img_list) for text, img_list in zip(texts, image_lists)] |
|
|
| batch_size = len(batch_embs) |
| max_len = max([emb.shape[1] for emb in batch_embs]) |
| emb_dim = batch_embs[0].shape[2] |
| dtype = batch_embs[0].dtype |
| device = batch_embs[0].device |
|
|
| embs = torch.zeros([batch_size, max_len, emb_dim], dtype=dtype, device=device) |
| attn_mask = torch.zeros([batch_size, max_len], dtype=torch.int, device=device) |
| for i, emb in enumerate(batch_embs): |
| emb_len = emb.shape[1] |
| embs[i, -emb_len:] = emb[0] |
| attn_mask[i, -emb_len:] = 1 |
|
|
| with self.maybe_autocast(): |
| outputs = self.llama_model.generate( |
| inputs_embeds=embs, |
| attention_mask=attn_mask, |
| max_new_tokens=max_new_tokens, |
| num_beams=num_beams, |
| length_penalty=length_penalty, |
| temperature=temperature, |
| do_sample=do_sample, |
| min_length=min_length, |
| top_p=top_p, |
| repetition_penalty=repetition_penalty, |
| stopping_criteria=stopping_criteria, |
| return_dict_in_generate=True, |
|
|
| output_hidden_states=output_hidden_states |
| |
| ) |
| hidden_states=outputs['hidden_states'] |
| |
| |
| |
| location=(outputs['sequences'] == self.bounding_box_label).nonzero().flatten() |
| |
| |
| |
| |
| |
| predict_boundingbox=[] |
| if location.numel() != 0: |
| location=location[-1] |
| |
| |
| |
| location_state=hidden_states[location] |
| |
| predict_boundingbox=self.seg_layer(location_state.view(1,-1).float()) |
| |
| |
| answers = [] |
| for output_token in outputs['sequences']: |
| if output_token[0] == 0: |
| output_token = output_token[1:] |
| |
| output_texts = self.llama_tokenizer.decode(output_token, skip_special_tokens=True) |
| output_texts = output_texts.split('</s>')[0] |
| output_texts = output_texts.replace("<s>", "") |
| output_texts = output_texts.split(r'[/INST]')[-1].strip() |
| answers.append(output_texts) |
| |
| if predict_boundingbox!=[]: |
| return answers,predict_boundingbox |
| else: |
| return answers,[] |
| |
| |
| |
|
|
|
|
|
|
| def predict_class_(self,images,reduction='mean'): |
| |
| |
| |
| cond_embeds, cond_atts, regress_embeds, regress_atts, part_targets,input_embs,class_target,self_loss = \ |
| self.preparing_embedding(images) |
|
|
| inputs_embeds, attention_mask, input_lens = \ |
| self.concat_emb_input_output(cond_embeds, cond_atts, regress_embeds, regress_atts) |
|
|
| |
| bos = torch.ones_like(part_targets[:, :1]) * self.llama_tokenizer.bos_token_id |
| bos_embeds = self.embed_tokens(bos) |
| bos_atts = cond_atts[:, :1] |
|
|
| |
| inputs_embeds = torch.cat([bos_embeds, inputs_embeds], dim=1) |
| attention_mask = torch.cat([bos_atts, attention_mask], dim=1) |
|
|
| |
| targets = torch.ones([inputs_embeds.shape[0], inputs_embeds.shape[1]], |
| dtype=torch.long).to(self.device).fill_(-100) |
|
|
| |
|
|
| with self.maybe_autocast(): |
| outputs = self.llama_model( |
| inputs_embeds=inputs_embeds, |
| attention_mask=attention_mask, |
| return_dict=True, |
| labels=None, |
| reduction=reduction |
| ) |
| hidden_states=outputs.hidden_states |
| class_predict=self.class_layer(hidden_states[:,:32,:].view(1,-1).float()) |
|
|
|
|
| |
| |
| |
| |
| return class_predict |
|
|
| def predict_boundingbox_(self,samples): |
| if 'several_modalities_image' in samples: |
| images=samples['several_modalities_image']['images'] |
| modalities=samples['several_modalities_image']['modalities'] |
| |
| device=images[0].device |
| temp=[] |
| for index,image in enumerate(images): |
| |
| img_embeds, img_atts,self_loss = self.encode_img(image.to(self.device),modalities[index]) |
| temp.append(img_embeds) |
| img_embeds=torch.cat(temp, dim=1) |
| |
| x=img_embeds.float() |
| x=self.decoder_linear(x) |
| for blk in self.decoder_blocks: |
| x = blk(x) |
| x = self.decoder_norm(x) |
|
|
| |
| x = self.decoder_pred(x.view(1,-1)) |
| |
| return x |
|
|
| @torch.no_grad() |
| def multi_select(self, images, texts, answers, num_cand=None): |
| all_losses = [] |
| for answer in answers: |
| choice_samples = { |
| 'image': images, |
| 'instruction_input': texts, |
| 'answer': answer |
| } |
| loss = self.forward(choice_samples, reduction='none')['loss'].reshape(-1, 1) |
| all_losses.append(loss) |
| torch.cuda.empty_cache() |
| all_losses = torch.cat(all_losses, dim=-1) |
| if num_cand is not None: |
| for i in range(all_losses.shape[0]): |
| all_losses[i, num_cand[i]:] = 9999 |
| output_class_ranks = torch.argsort(all_losses, dim=-1) |
| return output_class_ranks.tolist() |
| |
| def sample_top_p(probs, p): |
| """ |
| Perform top-p (nucleus) sampling on a probability distribution. |
| |
| Args: |
| probs (torch.Tensor): Probability distribution tensor. |
| p (float): Probability threshold for top-p sampling. |
| |
| Returns: |
| torch.Tensor: Sampled token indices. |
| |
| Note: |
| Top-p sampling selects the smallest set of tokens whose cumulative probability mass |
| exceeds the threshold p. The distribution is renormalized based on the selected tokens. |
| """ |
| probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True) |
| probs_sum = torch.cumsum(probs_sort, dim=-1) |
| mask = probs_sum - probs_sort > p |
| probs_sort[mask] = 0.0 |
| probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True)) |
| next_token = torch.multinomial(probs_sort, num_samples=1) |
| next_token = torch.gather(probs_idx, -1, next_token) |
| return next_token |
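

# Example usage (a minimal sketch): draw one token id from the top-p nucleus of a softmax
# distribution over an assumed vocabulary size.
#   logits = torch.randn(1, 32000)
#   probs = torch.softmax(logits / 0.8, dim=-1)   # temperature 0.8, chosen for illustration
#   next_token = sample_top_p(probs, p=0.9)       # shape (1, 1), an index into the vocabulary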