from abc import ABC, abstractmethod
import torch
import torch.nn as nn
from .multimodal_encoder.builder import build_vision_tower
from ChatUniVi.constants import *
from .cluster import CTM, TCBlock
from collections import OrderedDict
from .multimodal_projector.builder import build_vision_projector


class MetaModel:

    def __init__(self, config):
        super(MetaModel, self).__init__(config)

        if hasattr(config, "mm_vision_tower"):
            self.vision_tower = build_vision_tower(config, delay_load=True)
            self.mm_projector = nn.Linear(config.mm_hidden_size, config.hidden_size)

        if hasattr(config, "config"):
            self.use_cluster = config.config["use_cluster"]

            if self.use_cluster:
                self.ctm0 = CTM(sample_ratio=config.config["spatial_cluster_rate0"],
                                embed_dim=self.config.mm_hidden_size,
                                dim_out=self.config.mm_hidden_size, k=5)
                self.block0 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)

                self.ctm1 = CTM(sample_ratio=config.config["spatial_cluster_rate1"],
                                embed_dim=self.config.mm_hidden_size,
                                dim_out=self.config.mm_hidden_size, k=3)
                self.block1 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)

                self.ctm2 = CTM(sample_ratio=config.config["spatial_cluster_rate2"],
                                embed_dim=self.config.mm_hidden_size,
                                dim_out=self.config.mm_hidden_size, k=3)
                self.block2 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)

                self.ctm3 = CTM(sample_ratio=config.config["temporal_cluster_rate"],
                                embed_dim=self.config.mm_hidden_size,
                                dim_out=self.config.mm_hidden_size, k=5)
                self.block3 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)
        else:
            self.use_cluster = False

    def get_vision_tower(self):
        vision_tower = getattr(self, 'vision_tower', None)
        if type(vision_tower) is list:
            vision_tower = vision_tower[0]
        return vision_tower

    def initialize_vision_modules(self, model_args, fsdp=None):
        vision_tower = model_args.vision_tower
        mm_vision_select_layer = model_args.mm_vision_select_layer
        mm_vision_select_feature = model_args.mm_vision_select_feature
        pretrain_mm_mlp_adapter = model_args.pretrain_mm_mlp_adapter

        self.config.mm_vision_tower = vision_tower

        vision_tower = build_vision_tower(model_args)

        self.config.use_mm_proj = True
        self.config.mm_hidden_size = vision_tower.hidden_size
        self.config.mm_vision_select_layer = mm_vision_select_layer
        self.config.mm_vision_select_feature = mm_vision_select_feature

        if fsdp is not None and len(fsdp) > 0:
            self.vision_tower = [vision_tower]
        else:
            self.vision_tower = vision_tower

        if not hasattr(self, 'mm_projector'):
            self.mm_projector = build_vision_projector(self.config)

        if pretrain_mm_mlp_adapter is not None:
            mm_projector_weights = torch.load(pretrain_mm_mlp_adapter, map_location='cpu')

            def get_w(weights, keyword):
                return {k.split(keyword + '.')[1]: v for k, v in weights.items() if keyword in k}

            self.mm_projector.load_state_dict(get_w(mm_projector_weights, 'mm_projector'))

    def initialize_cluster_modules(self, model_args):
        self.use_cluster = model_args.use_cluster

        if self.use_cluster and not hasattr(self, 'ctm0'):
            self.ctm0 = CTM(sample_ratio=model_args.spatial_cluster_rate0,
                            embed_dim=self.config.mm_hidden_size,
                            dim_out=self.config.mm_hidden_size, k=5)
            self.block0 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)

            self.ctm1 = CTM(sample_ratio=model_args.spatial_cluster_rate1,
                            embed_dim=self.config.mm_hidden_size,
                            dim_out=self.config.mm_hidden_size, k=3)
            self.block1 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)

            self.ctm2 = CTM(sample_ratio=model_args.spatial_cluster_rate2,
                            embed_dim=self.config.mm_hidden_size,
                            dim_out=self.config.mm_hidden_size, k=3)
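            # Note: ctm0/ctm1/ctm2 merge visual tokens within a frame at progressively
            # coarser spatial levels, while ctm3 below clusters frame-level features over
            # time; project() uses that temporal grouping to split a video into events.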
            self.block2 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)

            self.ctm3 = CTM(sample_ratio=model_args.temporal_cluster_rate,
                            embed_dim=self.config.mm_hidden_size,
                            dim_out=self.config.mm_hidden_size, k=5)
            self.block3 = TCBlock(dim=self.config.mm_hidden_size, num_heads=8)


class ChatUniViMetaForCausalLM(ABC):

    @abstractmethod
    def get_model(self):
        pass

    def get_vision_tower(self):
        return self.get_model().get_vision_tower()

    def encode_images(self, images):
        image_features = self.get_model().get_vision_tower()(images, select_feature="patch")
        return image_features

    def positional_encoding(self, x, num_features=1024, max_len=64):
        p = torch.zeros((1, max_len, num_features))
        _x = torch.arange(max_len, dtype=torch.float32).reshape(-1, 1) / torch.pow(
            10000, torch.arange(0, num_features, 2, dtype=torch.float32) / num_features)
        p[:, :, 0::2] = torch.sin(_x)
        p[:, :, 1::2] = torch.cos(_x)
        x = x + p[:, :x.shape[1], :].to(x.device).to(x.dtype)
        return x

    def project(self, image_features, input_type="image"):
        if self.get_model().use_cluster:
            if input_type == "image":
                cluster_image_features = []
                token_dict = {'x': image_features,
                              'token_num': image_features.size(1),
                              'idx_token': torch.arange(image_features.size(1))[None, :].repeat(
                                  image_features.size(0), 1),
                              'agg_weight': image_features.new_ones(image_features.size(0),
                                                                    image_features.size(1), 1),
                              'mask': None}

                token_dict = self.get_model().block0(self.get_model().ctm0(token_dict))
                cluster_image_features.append(token_dict["x"])

                token_dict = self.get_model().block1(self.get_model().ctm1(token_dict))
                cluster_image_features.append(token_dict["x"])

                token_dict = self.get_model().block2(self.get_model().ctm2(token_dict))
                cluster_image_features.append(token_dict["x"])

                image_features = torch.cat(cluster_image_features, dim=1)
                image_features = image_features.to(self.get_model().mm_projector.weight.dtype)

            else:
                cls_features = torch.mean(image_features, dim=1, keepdim=False).unsqueeze(0).clone()
                token_dict = {'x': cls_features,
                              'token_num': cls_features.size(1),
                              'idx_token': torch.arange(cls_features.size(1))[None, :].repeat(
                                  cls_features.size(0), 1),
                              'agg_weight': cls_features.new_ones(cls_features.size(0),
                                                                  cls_features.size(1), 1),
                              'mask': None}

                down_dict, token_dict = self.get_model().ctm3(token_dict)

                events = OrderedDict()
                max_len = 0
                for id, i in enumerate(down_dict["idx_token"][0].tolist()):
                    if i not in events:
                        events[i] = [id]
                    else:
                        events[i].append(id)
                    max_len = len(events[i]) if max_len < len(events[i]) else max_len

                cluster_image_features = []
                token_dict = {'x': image_features,
                              'token_num': image_features.size(1),
                              'idx_token': torch.arange(image_features.size(1))[None, :].repeat(
                                  image_features.size(0), 1),
                              'agg_weight': image_features.new_ones(image_features.size(0),
                                                                    image_features.size(1), 1),
                              'mask': None}

                token_dict0 = self.get_model().block0(self.get_model().ctm0(token_dict))
                token_dict1 = self.get_model().block1(self.get_model().ctm1(token_dict0))
                token_dict2 = self.get_model().block2(self.get_model().ctm2(token_dict1))

                for id, key in enumerate(events):
                    cur_image_features0 = torch.cat([token_dict0["x"][i] for i in events[key]], dim=0).unsqueeze(0)
                    token_dict = {'x': cur_image_features0,
                                  'token_num': cur_image_features0.size(1),
                                  'idx_token': torch.arange(cur_image_features0.size(1))[None, :].repeat(
                                      cur_image_features0.size(0), 1),
                                  'agg_weight': cur_image_features0.new_ones(cur_image_features0.size(0),
                                                                             cur_image_features0.size(1), 1),
                                  'mask': None}
                    cur_token_dict0 = self.get_model().block0(self.get_model().ctm0(token_dict))
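                    # cur_token_dict0 re-clusters this event's frame tokens with the stage-0
                    # spatial modules; its tokens are appended below, followed by stage-1 and
                    # stage-2 clusters built the same way from token_dict1 and token_dict2,
                    # so every event contributes three spatial granularities.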
                    cluster_image_features.append(cur_token_dict0["x"])

                    cur_image_features1 = torch.cat([token_dict1["x"][i] for i in events[key]], dim=0).unsqueeze(0)
                    token_dict = {'x': cur_image_features1,
                                  'token_num': cur_image_features1.size(1),
                                  'idx_token': torch.arange(cur_image_features1.size(1))[None, :].repeat(
                                      cur_image_features1.size(0), 1),
                                  'agg_weight': cur_image_features1.new_ones(cur_image_features1.size(0),
                                                                             cur_image_features1.size(1), 1),
                                  'mask': None}
                    cur_token_dict1 = self.get_model().block1(self.get_model().ctm1(token_dict))
                    cluster_image_features.append(cur_token_dict1["x"])

                    cur_image_features2 = torch.cat([token_dict2["x"][i] for i in events[key]], dim=0).unsqueeze(0)
                    token_dict = {'x': cur_image_features2,
                                  'token_num': cur_image_features2.size(1),
                                  'idx_token': torch.arange(cur_image_features2.size(1))[None, :].repeat(
                                      cur_image_features2.size(0), 1),
                                  'agg_weight': cur_image_features2.new_ones(cur_image_features2.size(0),
                                                                             cur_image_features2.size(1), 1),
                                  'mask': None}
                    cur_token_dict2 = self.get_model().block2(self.get_model().ctm2(token_dict))
                    cluster_image_features.append(cur_token_dict2["x"])

                image_features = torch.cat(cluster_image_features, dim=1)
                image_features = image_features.to(self.get_model().mm_projector.weight.dtype)
        else:
            if input_type == "video":
                # Without clustering, a video is reduced to the mean over frames plus the
                # per-frame mean (cls) features, concatenated along the token dimension.
                image_features, cls_features = torch.mean(image_features, dim=0, keepdim=False).unsqueeze(0), \
                                               torch.mean(image_features, dim=1, keepdim=False).unsqueeze(0)
                image_features = torch.cat([image_features, cls_features], dim=1)

        image_features = self.get_model().mm_projector(image_features)
        return image_features

    # Features have the same shape regardless of input type.
    def prepare_inputs_labels_for_multimodal(
        self, input_ids, attention_mask, past_key_values, labels, images,
        audio_features=None, target_frame=0, ref_ids=None
    ):
        IMAGE_TOKEN_INDEX = -200
        AUDIO_TOKEN_INDEX = -300
        # print("\nprepare_inputs_labels_for_multimodal called")
        vision_tower = self.get_vision_tower()
        # print("got vision_tower")
        num_frames = images[0].shape[0]  # T

        if vision_tower is None or images is None or input_ids.shape[1] == 1:
            if past_key_values is not None and vision_tower is not None and images is not None \
                    and input_ids.shape[1] == 1:
                attention_mask = torch.ones((attention_mask.shape[0], past_key_values[-1][-1].shape[-2] + 1),
                                            dtype=attention_mask.dtype, device=attention_mask.device)
            return input_ids, attention_mask, past_key_values, None, labels

        if ref_ids is not None:
            ref_embeds = []
            for ref_id in ref_ids:
                ref_embed = self.get_model().embed_tokens(ref_id)  # [L, 4096]
                ref_embeds.append(ref_embed)  # list[B]: [len_ref, 4096]

        if type(images) is list or images.ndim == 5:
            # print("concatenate the images in the list first")
            concat_images = torch.cat([image for image in images], dim=0)  # [BT, 3, H, W]
            org_image_features = self.encode_images(concat_images)  # [BT, 256, 1024]
            # if audio_features is not None and hasattr(self, "audio_adapter"):
            if True:
                # image_features = self.audio_adapter(org_image_features, audio_features, ref_embeds_T)
                # image_features = self.token_compressor(org_image_features, ref_embeds)
                # print("image_features after compress:", image_features.shape)
                image_features = org_image_features
            else:
                image_features = org_image_features
            # split_sizes = [image.shape[0] for image in images]
            split_sizes = 1
            image_features = torch.split(image_features, split_sizes, dim=0)  # list[BT]: [1, 256, 1024]
            image_features = [x.flatten(0, 1) for x in image_features]  # list[BT]: [256, 1024]
            org_image_features = torch.split(org_image_features, split_sizes, dim=0)
            org_image_features = [x.flatten(0, 1) for x in org_image_features]
        else:
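            # images is already a single batched tensor here (not a list and not 5-D),
            # so it is encoded directly without the concatenate/split round trip above.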
print("直接获取image_feature") image_features = self.encode_images(images) org_image_features = image_features new_input_embeds = [] new_labels = [] if labels is not None else None cur_image_idx = 0 for batch_idx, cur_input_ids in enumerate(input_ids): # cur_image_idx += 1 # 判断当前input_id中有没有图像token # print("cur_input_ids shape:", cur_input_ids.shape) # print("cur_input_ids:", cur_input_ids) if (cur_input_ids == IMAGE_TOKEN_INDEX).sum() == 0: # print("input_ids中没有 IMAGE token") # multimodal LLM, but the current sample is not multimodal # 直接把input_ids进行text embed cur_input_embeds = self.get_model().embed_tokens(cur_input_ids) cur_input_embeds = cur_input_embeds + ( 0. * self.get_model().mm_projector(vision_tower.dummy_feature)).sum() new_input_embeds.append(cur_input_embeds) if labels is not None: new_labels.append(labels[batch_idx]) cur_image_idx += 1 continue image_token_indices = torch.where((cur_input_ids == IMAGE_TOKEN_INDEX)|(cur_input_ids == AUDIO_TOKEN_INDEX))[0] audio_token_indices = torch.where(cur_input_ids == AUDIO_TOKEN_INDEX)[0] # print("audio indices:", audio_token_indices) # print("image and audio indices:", image_token_indices) cur_new_input_embeds = [] if labels is not None: cur_labels = labels[batch_idx] cur_new_labels = [] assert cur_labels.shape == cur_input_ids.shape # 有多个image token--------------------------------------------- if len(image_token_indices) > 1: # print("有多个image token") # return 0 temp = [] cur, pre = image_token_indices[0], image_token_indices[0] # 这里是把连续的的位置放到一个list中存储 分割开的 for i in image_token_indices: cur = i # 如果下一个就在上一个之后 if cur - pre == 1: temp[-1] = temp[-1] + [cur] else: temp.append([cur]) pre = cur pre_image_token_end = 0 cur_frames = 0 for i in temp: # 第一个以及最后一个的位置 image_token_start = i[0] image_token_end = i[-1] cur_image_features = [] if len(i) >= 2: # 处理T个image组成的视频特征 for frame_idx in range(num_frames): cur_image_features.append(org_image_features[batch_idx*num_frames+frame_idx]) # print(batch_idx*num_frames+frame_idx) elif i[0] not in audio_token_indices: cur_image_features.append(org_image_features[batch_idx * num_frames + target_frame]) # print(batch_idx * num_frames + target_frame) else: cur_image_features.append(audio_features[batch_idx]) # print(f"audio{batch_idx}") # ------------------------------------------------------------------ # # i是每组的indices 根据其数量从image_features中拿特征 # for _ in i: # # 表示处理的是 # if _ not in audio_token_indices: # # 单个image # if cur_frames == num_frames: # # cur_image_features.append(org_image_features[cur_image_idx-num_frames+target_frame]) # cur_image_features.append(org_image_features[batch_idx*num_frames+target_frame]) # # print(cur_image_idx-num_frames+target_frame) # # 多个image # else: # cur_image_features.append(image_features[cur_image_idx]) # # print(cur_image_idx) # cur_image_idx += 1 # cur_frames += 1 # # 处理