# Copyright (c) 2023-2024 DeepSeek.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of
# the Software, and to permit persons to whom the Software is furnished to do so,
# subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR
# COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER
# IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN
# CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
from dataclasses import dataclass

import torch
from attrdict import AttrDict
from einops import rearrange
from transformers import (
AutoConfig,
AutoModelForCausalLM,
LlamaConfig,
LlamaForCausalLM,
PreTrainedModel,
)
from transformers.modeling_outputs import CausalLMOutputWithPast
from torch.nn import CrossEntropyLoss
from transformers.configuration_utils import PretrainedConfig
from janus.models.clip_encoder import CLIPVisionTower
from janus.models.projector import MlpProjector
class vision_head(torch.nn.Module):
def __init__(self, params):
super().__init__()
self.output_mlp_projector = torch.nn.Linear(
params.n_embed, params.image_token_embed
)
self.vision_activation = torch.nn.GELU()
self.vision_head = torch.nn.Linear(
params.image_token_embed, params.image_token_size
)
def forward(self, x):
x = self.output_mlp_projector(x)
x = self.vision_activation(x)
x = self.vision_head(x)
return x
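
# A minimal usage sketch for vision_head (the sizes here are hypothetical; the
# real values come from gen_head_config.params in the checkpoint config):
#
#   head = vision_head(AttrDict(n_embed=2048, image_token_embed=2048,
#                               image_token_size=16384))
#   logits = head(torch.randn(1, 576, 2048))  # -> [1, 576, 16384]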
def model_name_to_cls(cls_name):
if "MlpProjector" in cls_name:
cls = MlpProjector
elif "CLIPVisionTower" in cls_name:
cls = CLIPVisionTower
elif "VQ" in cls_name:
from janus.models.vq_model import VQ_models
cls = VQ_models[cls_name]
elif "vision_head" in cls_name:
cls = vision_head
else:
raise ValueError(f"class_name {cls_name} is invalid.")
return cls
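
# For example, model_name_to_cls("MlpProjector") returns MlpProjector, while a
# cls string such as "VQ-16" (the exact key depends on the checkpoint config)
# is looked up in janus.models.vq_model.VQ_models.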
class VisionConfig(PretrainedConfig):
model_type = "vision"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class AlignerConfig(PretrainedConfig):
model_type = "aligner"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class GenVisionConfig(PretrainedConfig):
model_type = "gen_vision"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class GenAlignerConfig(PretrainedConfig):
model_type = "gen_aligner"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
class GenHeadConfig(PretrainedConfig):
model_type = "gen_head"
cls: str = ""
params: AttrDict = {}
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.cls = kwargs.get("cls", "")
if not isinstance(self.cls, str):
self.cls = self.cls.__name__
self.params = AttrDict(kwargs.get("params", {}))
@dataclass
class VLChatProcessorOutput:
sft_format: str
input_ids: torch.Tensor
pixel_values: torch.Tensor
num_image_tokens: torch.IntTensor
def __len__(self):
return len(self.input_ids)
class MultiModalityConfig(PretrainedConfig):
model_type = "multi_modality"
vision_config: VisionConfig
aligner_config: AlignerConfig
gen_vision_config: GenVisionConfig
gen_aligner_config: GenAlignerConfig
gen_head_config: GenHeadConfig
language_config: LlamaConfig
def __init__(self, **kwargs):
super().__init__(**kwargs)
vision_config = kwargs.get("vision_config", {})
self.vision_config = VisionConfig(**vision_config)
aligner_config = kwargs.get("aligner_config", {})
self.aligner_config = AlignerConfig(**aligner_config)
gen_vision_config = kwargs.get("gen_vision_config", {})
self.gen_vision_config = GenVisionConfig(**gen_vision_config)
gen_aligner_config = kwargs.get("gen_aligner_config", {})
self.gen_aligner_config = GenAlignerConfig(**gen_aligner_config)
gen_head_config = kwargs.get("gen_head_config", {})
self.gen_head_config = GenHeadConfig(**gen_head_config)
language_config = kwargs.get("language_config", {})
if isinstance(language_config, LlamaConfig):
self.language_config = language_config
else:
self.language_config = LlamaConfig(**language_config)
class MultiModalityPreTrainedModel(PreTrainedModel):
config_class = MultiModalityConfig
base_model_prefix = "multi_modality"
_no_split_modules = []
_skip_keys_device_placement = "past_key_values"
class MultiModalityCausalLM(MultiModalityPreTrainedModel):
def __init__(self, config: MultiModalityConfig):
super().__init__(config)
vision_config = config.vision_config
vision_cls = model_name_to_cls(vision_config.cls)
self.vision_model = vision_cls(**vision_config.params)
aligner_config = config.aligner_config
aligner_cls = model_name_to_cls(aligner_config.cls)
self.aligner = aligner_cls(aligner_config.params)
gen_vision_config = config.gen_vision_config
gen_vision_cls = model_name_to_cls(gen_vision_config.cls)
self.gen_vision_model = gen_vision_cls()
gen_aligner_config = config.gen_aligner_config
gen_aligner_cls = model_name_to_cls(gen_aligner_config.cls)
self.gen_aligner = gen_aligner_cls(gen_aligner_config.params)
gen_head_config = config.gen_head_config
gen_head_cls = model_name_to_cls(gen_head_config.cls)
self.gen_head = gen_head_cls(gen_head_config.params)
self.gen_embed = torch.nn.Embedding(
gen_vision_config.params.image_token_size, gen_vision_config.params.n_embed
)
language_config = config.language_config
self.language_model = LlamaForCausalLM(language_config)
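
        # Dataflow sketch: understanding runs vision_model -> aligner -> LM;
        # generation runs LM hidden states -> gen_head -> VQ codebook logits,
        # and gen_embed -> gen_aligner maps sampled codes back into the LM
        # embedding space for the next autoregressive step.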
def prepare_inputs_embeds(
self,
input_ids: torch.LongTensor,
pixel_values: torch.FloatTensor,
        images_seq_mask: torch.BoolTensor = None,
        images_emb_mask: torch.BoolTensor = None,
**kwargs,
):
"""
Args:
input_ids (torch.LongTensor): [b, T]
pixel_values (torch.FloatTensor): [b, n_images, 3, h, w]
images_seq_mask (torch.BoolTensor): [b, T]
images_emb_mask (torch.BoolTensor): [b, n_images, n_image_tokens]
assert torch.sum(images_seq_mask) == torch.sum(images_emb_mask)
Returns:
input_embeds (torch.Tensor): [b, T, D]
"""
bs, n = pixel_values.shape[0:2]
images = rearrange(pixel_values, "b n c h w -> (b n) c h w")
# [b x n, T2, D]
images_embeds = self.aligner(self.vision_model(images))
# [b x n, T2, D] -> [b, n x T2, D]
images_embeds = rearrange(images_embeds, "(b n) t d -> b (n t) d", b=bs, n=n)
# [b, n, T2] -> [b, n x T2]
images_emb_mask = rearrange(images_emb_mask, "b n t -> b (n t)")
# [b, T, D]
        input_ids[input_ids < 0] = 0  # image placeholders use negative ids; zero them so the embedding lookup is valid
inputs_embeds = self.language_model.get_input_embeddings()(input_ids)
# replace with the image embeddings
inputs_embeds[images_seq_mask] = images_embeds[images_emb_mask]
return inputs_embeds
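
    # Shape sketch for prepare_inputs_embeds (sizes are illustrative): with
    # b=1, a prompt of T tokens whose image slots span 576 positions, and one
    # 384x384 image,
    #   input_ids:       [1, T]
    #   pixel_values:    [1, 1, 3, 384, 384]
    #   images_seq_mask: [1, T]        (True at the 576 image positions)
    #   images_emb_mask: [1, 1, 576]   (True for every image embedding)
    # the returned inputs_embeds is [1, T, D] with the masked positions
    # replaced by the aligned vision features.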
def prepare_gen_img_embeds(self, image_ids: torch.LongTensor):
return self.gen_aligner(self.gen_embed(image_ids))
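
    # e.g. image_ids of shape [b, 576] (VQ codebook indices) are embedded by
    # gen_embed into [b, 576, n_embed] and projected by gen_aligner into the
    # LM embedding space.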
    def forward(
        self,
        vl_chat_processor,
        input_ids,
        labels=None,
        task="understanding",
        return_dict=True,
        pixel_values=None,
        images_seq_mask=None,
        images_emb_mask=None,
        **kwargs,
    ):
if task == "understanding":
inputs_embeds = self.prepare_inputs_embeds(input_ids, pixel_values, images_seq_mask, images_emb_mask)
return self.language_model.forward(
inputs_embeds=inputs_embeds,
labels=labels,
**kwargs
)
elif task == "generation":
            cfg_weight = 5
            # Classifier-free guidance: interleave a conditional and an
            # unconditional copy of each prompt; the unconditional copy keeps
            # only the first and last tokens and pads everything in between.
            tokens = torch.zeros((2 * input_ids.size(0), input_ids.size(1)), dtype=torch.int).cuda()
            for i in range(2 * input_ids.size(0)):
                tokens[i, :] = input_ids[i // 2, :]
                if i % 2 != 0:
                    tokens[i, 1:-1] = 100015  # pad_id
inputs_embeds = self.language_model.get_input_embeddings()(tokens)
            # The base LlamaModel takes no labels argument; the CFG loss is
            # computed below from the gen_head logits.
            outputs = self.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=None)
hidden_states = outputs.last_hidden_state
logits = self.gen_head(hidden_states)
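            # Classifier-free guidance combine: logits = uncond + w * (cond - uncond).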
logits_cond = logits[0::2, :]
logits_uncond = logits[1::2, :]
all_logits = logits_uncond + cfg_weight * (logits_cond - logits_uncond)
            if labels is not None:
                loss_fct = CrossEntropyLoss()
                # Shift so that tokens < n predict n.
                shift_logits = all_logits[..., :-1, :].contiguous()
                shift_logits = shift_logits.view(-1, self.config.gen_head_config.params.image_token_size)
                shift_labels = labels[..., 1:].contiguous()
                shift_labels = shift_labels.view(-1)
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
            else:
                loss = None
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
elif task == "generation_direct":
outputs = self.language_model.model(input_ids=input_ids, **kwargs)
            hidden_states = outputs[0]  # last_hidden_state of the base model
            logits = self.gen_head(hidden_states)
logits = logits.float()
            # Shift so that tokens < n predict n.
            shift_logits = logits[..., :-1, :].contiguous()
            shift_logits = shift_logits.view(-1, self.config.gen_head_config.params.image_token_size)
            if labels is not None:
                shift_labels = labels[..., 1:].contiguous()
                # Flatten the tokens.
                loss_fct = CrossEntropyLoss()
                shift_labels = shift_labels.view(-1)
                # Enable model parallelism.
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
            else:
                loss = None
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
elif task == "image_editing":
            # image_token_num_per_image = 576
            # img_size = 384
            # patch_size = 16
            # cfg_weight = kwargs.get('cfg_weight', 5)
            # cfg_weight2 = kwargs.get('cfg_weight2', 5)
            # temperature = kwargs.get('temperature', 1.0)
            # parallel_size = kwargs.get('parallel_size', input_ids.size(0))
            #
            # # Build tokens: three versions per input (cond_full, cond_part, uncond)
            # tokens = torch.zeros((3 * input_ids.size(0), input_ids.size(1)), dtype=torch.int).cuda()
            # pre_data = []
            # img_len = len(kwargs['source_image'])
            #
            # # Load the source images
            # import PIL.Image
            # images = [PIL.Image.open(image_path).convert("RGB") for image_path in kwargs['source_image']]
            # encoder_pixel_values = vl_chat_processor.image_processor(images, return_tensors="pt")['pixel_values']
            #
            # # Build the tokens for the three conditioning variants of each sample
            # for i in range(3 * input_ids.size(0)):
            #     tokens[i, :] = input_ids[i // 3, :]
            #     if i % 3 == 2:  # uncond version: replace the middle tokens with pad_id
            #         tokens[i, 1:-1] = 100015  # pad_id
            #
            #         # Append the completed triple to pre_data
            #         pre_data.append(VLChatProcessorOutput(
            #             sft_format=kwargs['sft_format'][i // 3],
            #             pixel_values=encoder_pixel_values[i // 3, :],
            #             input_ids=tokens[i - 2],
            #             num_image_tokens=[vl_chat_processor.num_image_tokens] * 1
            #         ))
            #         pre_data.append(VLChatProcessorOutput(
            #             sft_format=kwargs['sft_format'][i // 3],
            #             pixel_values=encoder_pixel_values[i // 3, :],
            #             input_ids=tokens[i - 1],
            #             num_image_tokens=[vl_chat_processor.num_image_tokens] * 1
            #         ))
            #         pre_data.append(VLChatProcessorOutput(
            #             sft_format=kwargs['sft_format'][i // 3],
            #             pixel_values=None,
            #             input_ids=tokens[i],
            #             num_image_tokens=[]
            #         ))
            #
            # # Batch the preprocessed inputs
            # prepare_inputs = vl_chat_processor.batchify(pre_data)
            #
            # # Prepare the input embeddings
            # inputs_embeds = self.prepare_inputs_embeds(
            #     input_ids=tokens.cuda(),
            #     pixel_values=prepare_inputs['pixel_values'].to(torch.bfloat16).cuda(),
            #     images_emb_mask=prepare_inputs['images_emb_mask'].cuda(),
            #     images_seq_mask=prepare_inputs['images_seq_mask'].cuda()
            # )
            #
            # # Encode the source images with the generative vision tower
            # input_image_pixel_values = vl_chat_processor.image_processor(images, return_tensors="pt")[
            #     'pixel_values'].to(torch.bfloat16).cuda()
            # quant_input, emb_loss_input, info_input = self.gen_vision_model.encode(input_image_pixel_values)
            # image_tokens_input = info_input[2].detach().reshape(input_image_pixel_values.shape[0], -1)
            # image_embeds_input = self.prepare_gen_img_embeds(image_tokens_input)
            #
            # # Insert the source-image embeddings at the correct positions
            # ppp = (tokens == 100580).nonzero()  # find the image-token positions
            # for ii, ind in enumerate(ppp):
            #     if ii % 4 == 0:
            #         offset = ind[1] + 2
            #         inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[
            #             (ii // 2) % img_len]
            #
            # # **Training mode: compute the loss only, do not generate images**
            # labels = None
            # if labels is not None:
            #     outputs = self.language_model.model(
            #         inputs_embeds=inputs_embeds,
            #         use_cache=True,
            #         past_key_values=None
            #     )
            #     hidden_states = outputs.last_hidden_state
            #     logits = self.gen_head(hidden_states)
            #
            #     # Split the logits of the three conditioning variants
            #     logit_cond_full = logits[0::3, :]
            #     logit_cond_part = logits[1::3, :]
            #     logit_uncond = logits[2::3, :]
            #
            #     # Combine the logits
            #     logit_cond = (logit_cond_full + cfg_weight2 * logit_cond_part) / (1 + cfg_weight2)
            #     all_logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            #
            #     # Compute the loss
            #     loss_fct = CrossEntropyLoss()
            #     shift_logits = all_logits[..., :-1, :].contiguous()
            #     shift_logits = shift_logits.view(-1, self.config.gen_head_config.params.image_token_size)
            #     shift_labels = labels[..., 1:].contiguous().view(-1).to(shift_logits.device)
            #     loss = loss_fct(shift_logits, shift_labels)
            #
            #     if not return_dict:
            #         output = (all_logits,) + outputs[1:]
            #         return ((loss,) + output) if loss is not None else output
            #
            #     return CausalLMOutputWithPast(
            #         loss=loss,
            #         logits=all_logits,
            #         past_key_values=outputs.past_key_values,
            #         hidden_states=outputs.hidden_states,
            #         attentions=outputs.attentions,
            #     )
            #
            # # **Inference mode: autoregressive image generation**
            # else:
            #     import numpy as np
            #     with torch.inference_mode():
            #         generated_tokens = torch.zeros((parallel_size, image_token_num_per_image), dtype=torch.int).cuda()
            #         outputs = None
            #
            #         # Autoregressive generation loop
            #         for i in range(image_token_num_per_image):
            #             # Forward pass
            #             outputs = self.language_model.model(
            #                 inputs_embeds=inputs_embeds,
            #                 use_cache=True,
            #                 past_key_values=outputs.past_key_values if i != 0 else None
            #             )
            #             hidden_states = outputs.last_hidden_state
            #
            #             # Logits of the last token only
            #             logits = self.gen_head(hidden_states[:, -1, :])
            #
            #             # Split the logits of the three conditioning variants
            #             logit_cond_full = logits[0::3, :]
            #             logit_cond_part = logits[1::3, :]
            #             logit_uncond = logits[2::3, :]
            #
            #             # Combine the logits
            #             logit_cond = (logit_cond_full + cfg_weight2 * logit_cond_part) / (1 + cfg_weight2)
            #             combined_logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            #
            #             # Sample the next token
            #             probs = torch.softmax(combined_logits / temperature, dim=-1)
            #             next_token = torch.multinomial(probs, num_samples=1)
            #             generated_tokens[:, i] = next_token.squeeze(dim=-1)
            #
            #             # Prepare the input embeddings for the next step
            #             if i < image_token_num_per_image - 1:
            #                 # Expand next_token into three copies
            #                 next_token_expanded = torch.cat([
            #                     next_token.unsqueeze(dim=1),
            #                     next_token.unsqueeze(dim=1),
            #                     next_token.unsqueeze(dim=1)
            #                 ], dim=1).view(-1)
            #
            #                 # Embeddings of the next token
            #                 img_embeds = self.prepare_gen_img_embeds(next_token_expanded)
            #                 inputs_embeds = img_embeds.unsqueeze(dim=1)
            #
            #         # Decode the generated tokens into images
            #         dec = self.gen_vision_model.decode_code(
            #             generated_tokens.to(dtype=torch.int),
            #             shape=[parallel_size, 8, img_size // patch_size, img_size // patch_size]
            #         )
            #         dec = dec.to(torch.float32).cpu().numpy().transpose(0, 2, 3, 1)
            #         dec = np.clip((dec + 1) / 2 * 255, 0, 255)
            #
            #         # Build the output image array with the expected shape
            #         visual_img = np.zeros((parallel_size, img_size, img_size, 3), dtype=np.uint8)
            #         visual_img[:, :, :] = dec
            #
            #         # Optional: save a debug image (inference mode only)
            #         # if kwargs.get('save_debug_image', False):
            #         #     # make sure a single image of shape (384, 384, 3) is saved
            #         debug_img = visual_img[0] if visual_img.shape[0] > 0 else visual_img
            #         PIL.Image.fromarray(debug_img).save('/home/ps/Bxh/align-anything/debug_output.png')
            # Three interleaved copies per sample: full condition, partial
            # condition, and unconditional (for two-stage CFG).
            cfg_weight = 5
            tokens = torch.zeros((3 * input_ids.size(0), input_ids.size(1)), dtype=torch.int).cuda()
            pre_data = []

            import PIL.Image

            images = [PIL.Image.open(image_path).convert("RGB") for image_path in kwargs['source_image']]
            encoder_pixel_values = vl_chat_processor.image_processor(images, return_tensors="pt")['pixel_values']
            for i in range(3 * input_ids.size(0)):
                tokens[i, :] = input_ids[i // 3, :]
                if i % 3 == 2:
                    # Unconditional copy: replace everything between the first
                    # and last token with pad_id.
                    tokens[i, 1:-1] = 100002
                    # The (cond_full, cond_part, uncond) triple is complete;
                    # queue it for batchify.
                    pre_data.append(VLChatProcessorOutput(
                        sft_format=kwargs['sft_format'][i // 3],
                        pixel_values=encoder_pixel_values[i // 3, :],
                        input_ids=tokens[i - 2],
                        num_image_tokens=[vl_chat_processor.num_image_tokens] * 1,
                    ))
                    pre_data.append(VLChatProcessorOutput(
                        sft_format=kwargs['sft_format'][i // 3],
                        pixel_values=encoder_pixel_values[i // 3, :],
                        input_ids=tokens[i - 1],
                        num_image_tokens=[vl_chat_processor.num_image_tokens] * 1,
                    ))
                    pre_data.append(VLChatProcessorOutput(
                        sft_format=kwargs['sft_format'][i // 3],
                        pixel_values=None,
                        input_ids=tokens[i],
                        num_image_tokens=[],
                    ))
            # Positions of the image placeholder token (id 100580) in the prompts.
            ppp = (tokens == 100580).nonzero()
            prepare_inputs = vl_chat_processor.batchify(pre_data)
inputs_embeds = self.prepare_inputs_embeds(
input_ids=tokens.cuda(),
pixel_values=prepare_inputs['pixel_values'].to(torch.bfloat16).cuda(),
images_emb_mask=prepare_inputs['images_emb_mask'].cuda(),
images_seq_mask=prepare_inputs['images_seq_mask'].cuda()
)
            # Encode the source images with the generative VQ tower (the pixel
            # values are the same as encoder_pixel_values computed above).
            input_image_pixel_values = encoder_pixel_values.to(torch.bfloat16).cuda()
            quant_input, emb_loss_input, info_input = self.gen_vision_model.encode(input_image_pixel_values)
            image_tokens_input = info_input[2].detach().reshape(input_image_pixel_values.shape[0], -1)
            image_embeds_input = self.prepare_gen_img_embeds(image_tokens_input)
            # Overwrite the placeholder span of the conditional copies with the
            # VQ embeddings of the corresponding source image.
            for ii, ind in enumerate(ppp):
                if ii % 4 == 0:
                    offset = ind[1] + 2
                    inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[ii // 4]
            # The base LlamaModel takes no labels argument; the loss is
            # computed below from the CFG-combined gen_head logits.
            outputs = self.language_model.model(inputs_embeds=inputs_embeds, use_cache=True, past_key_values=None)
            hidden_states = outputs.last_hidden_state
            logits = self.gen_head(hidden_states)  # [3 * b, T, image_token_size]
            # Two-stage classifier-free guidance: average the full- and
            # partial-condition logits, then extrapolate away from the
            # unconditional logits.
            logit_cond_full = logits[0::3, :]
            logit_cond_part = logits[1::3, :]
            logit_uncond = logits[2::3, :]
            cfg_weight2 = 5
            logit_cond = (logit_cond_full + cfg_weight2 * logit_cond_part) / (1 + cfg_weight2)
            all_logits = logit_uncond + cfg_weight * (logit_cond - logit_uncond)
            if labels is not None:
                loss_fct = CrossEntropyLoss()
                shift_logits = all_logits[..., :-1, :].contiguous()
                shift_logits = shift_logits.view(-1, self.config.gen_head_config.params.image_token_size)
                shift_labels = labels[..., 1:].contiguous()
                shift_labels = shift_labels.view(-1)
                shift_labels = shift_labels.to(shift_logits.device)
                loss = loss_fct(shift_logits, shift_labels)
            else:
                loss = None
if not return_dict:
output = (logits,) + outputs[1:]
return ((loss,) + output) if loss is not None else output
return CausalLMOutputWithPast(
loss=loss,
logits=logits,
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
AutoConfig.register("vision", VisionConfig)
AutoConfig.register("aligner", AlignerConfig)
AutoConfig.register("gen_vision", GenVisionConfig)
AutoConfig.register("gen_aligner", GenAlignerConfig)
AutoConfig.register("gen_head", GenHeadConfig)
AutoConfig.register("multi_modality", MultiModalityConfig)
AutoModelForCausalLM.register(MultiModalityConfig, MultiModalityCausalLM)
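
# A minimal loading sketch (the path is hypothetical; the repo must ship this
# module next to its config.json):
#
#   config = AutoConfig.from_pretrained("path/to/checkpoint")
#   model = AutoModelForCausalLM.from_config(config)
#   # or, from a hub repo that bundles this file:
#   # model = AutoModelForCausalLM.from_pretrained(
#   #     "path/to/checkpoint", trust_remote_code=True,
#   # )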