|
|
from transformers.modeling_utils import PreTrainedModel |
|
|
from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
import torch |
|
|
import pdb |
|
|
from transformers import OffloadedCache,DynamicCache |
|
|
from .configuration_mic21 import MIC21SummarizerConfig |
|
|
import numpy as np |
|
|
from transformers import AutoImageProcessor, ResNetForImageClassification |
|
|
|
|
|
class MIC21SummarizerModel(PreTrainedModel):
    """Image-to-title summarizer.

    Architecture: a frozen ResNet image encoder and a frozen fp16 causal LLM,
    bridged by a small trainable projection stack (LayerNorm -> Dropout ->
    Linear) that maps flattened ResNet feature maps into the LLM embedding
    space. Only the projection stack is trainable; both pretrained backbones
    are frozen in ``init_components``.
    """

    config_class = MIC21SummarizerConfig

    # Trainer-integration flags: submodels are moved to the GPU manually via
    # ``.cuda()`` in ``init_components``, so the Trainer must not relocate the
    # model itself (``place_model_on_device = False``).
    is_parallelizable = True
    model_parallel = True
    place_model_on_device = False
    # NOTE(review): mutable class-level attribute shared by all instances —
    # presumably a Trainer workaround; confirm before changing.
    model_wrapped = {}

    def __init__(self, config):
        """Build the trainable bridge layers.

        The heavy pretrained components are NOT loaded here; call
        ``init_components`` afterwards. They live in a plain dict so they are
        not registered as submodules (and are therefore excluded from this
        model's state dict).

        Args:
            config: a ``MIC21SummarizerConfig`` instance.
        """
        super().__init__(config)

        # Lazily-populated frozen components (see init_components).
        self.components = {"image_model": None, "llm": None, "tokenizer": None, "image_processor": None}
        self.hf_config = config

        # Bridge from flattened ResNet feature maps (7x7 = 49 spatial
        # positions per channel) to the LLM embedding width (2048).
        # NOTE(review): 49 and 2048 are hard-coded for the specific
        # ResNet/LLM pair — consider deriving them from the loaded configs.
        self.projection_layer = torch.nn.Linear(49, 2048, dtype=torch.float)
        self.projection_norm = torch.nn.LayerNorm(49, eps=1e-5, bias=True)
        self.projection_dropout = torch.nn.Dropout(0.1)

        # NOTE(review): im_model_cuda_id is stored but never used — all
        # ``.cuda()`` calls target the default device.
        self.im_model_cuda_id = config.im_model_cuda_id
        # Maximum number of tokens to generate; ``None`` falls back to 64.
        self.output_length = config.output_length

    def init_components(self):
        """Load the pretrained image encoder, LLM, tokenizer and image
        processor onto the GPU, then freeze both backbones so only the
        projection stack trains."""
        self.components["image_model"] = ResNetForImageClassification.from_pretrained(self.hf_config.hf_image_model).cuda()
        self.components["image_processor"] = AutoImageProcessor.from_pretrained(self.hf_config.hf_image_model)

        self.components["llm"] = AutoModelForCausalLM.from_pretrained(self.hf_config.hf_text_model, torch_dtype=torch.float16).cuda()
        self.components["tokenizer"] = AutoTokenizer.from_pretrained(self.hf_config.hf_text_model)

        # Freeze both pretrained backbones.
        for param in self.components["image_model"].parameters():
            param.requires_grad = False
        for param in self.components["llm"].parameters():
            param.requires_grad = False

    def forward(self, images, titles):
        """Greedily generate a title for each image; when reference ``titles``
        are supplied, also compute a cross-entropy training loss.

        Args:
            images: batch of raw images accepted by the image processor.
            titles: list of reference title strings, or ``None`` at inference.

        Returns:
            ``{"loss", "logits", "eval_loss"}`` when ``titles`` is given,
            otherwise ``{"logits"}``. ``logits`` has shape
            (batch, max_len + 1, vocab) in both cases.
        """
        # --- frozen image encoder ------------------------------------------
        prepared_images = self.components["image_processor"](images, return_tensors="pt")
        prepared_images["pixel_values"] = prepared_images["pixel_values"].cuda()

        img_features = self.components["image_model"](**prepared_images, output_hidden_states=True)
        # Last hidden state: (batch, channels, nx, ny).
        img_features = img_features["hidden_states"][-1]
        (batch_size, nfilter, nx, ny) = img_features.shape
        # Flatten spatial dims: one (nx*ny)-dim vector per channel.
        img_features = img_features.view(batch_size, nfilter, nx * ny)

        # --- prompt construction -------------------------------------------
        messages = [
            {"role":"system","content":"Generate title and description for the provided image. The image features are: "},
            {"role":"user","content":"Generate a title:"}]

        tokenized_messages = self.components["tokenizer"].apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").cuda()

        # Embed the prompt once, then replicate across the batch.
        vectorized_messages = self.components["llm"].model.embed_tokens(tokenized_messages[0]).unsqueeze(0)
        vectorized_messages = vectorized_messages.repeat(batch_size, 1, 1)

        # Splice point: first EOS token in the templated prompt (presumably
        # the end of the system turn) — visual embeddings go just before it.
        first_eos_index = (tokenized_messages[0] == self.components["tokenizer"].eos_token_id).nonzero()[0].item()

        # Trainable bridge: norm -> dropout -> project the first 256 channel
        # vectors into the LLM embedding space.
        # NOTE(review): assumes nfilter >= 256 — confirm for the configured
        # ResNet variant.
        visual_embeddings = self.projection_layer(self.projection_dropout(self.projection_norm(img_features[:, 0:256, :])))

        combined_embeds = torch.cat([
            vectorized_messages[:, :first_eos_index - 1, :],
            visual_embeddings.half(),  # match the fp16 LLM embeddings
            vectorized_messages[:, first_eos_index:, :]], dim=1)

        # --- greedy decoding with an offloaded KV cache --------------------
        self.cache = OffloadedCache()

        outputs = self.components["llm"](inputs_embeds=combined_embeds, past_key_values=self.cache, use_cache=True)
        logits = outputs.logits[:, -1]
        out_logits = logits.unsqueeze(1)
        new_tok = torch.argmax(logits, dim=-1)

        max_len = 64 if self.output_length is None else self.output_length

        # NOTE(review): the original EOS early-stop was guarded by
        # ``max_len is None``, which can never hold here (max_len is always an
        # int), so generation always runs the full max_len steps; that dead
        # guard has been removed without changing behavior.
        for _ in range(max_len):
            # Feed only the newly generated token; the cache holds the rest.
            outputs = self.components["llm"](input_ids=new_tok.unsqueeze(0).permute(1, 0), past_key_values=self.cache, use_cache=True)
            logits = outputs.logits[:, -1]
            # ``out_logits`` is always initialized above, so just append.
            out_logits = torch.cat([out_logits, logits.unsqueeze(1)], dim=1)
            new_tok = torch.argmax(logits, dim=-1)

        if titles is not None:
            # Pad targets to the generated length (max_len + 1 steps).
            # NOTE(review): pad tokens are not masked (no ignore_index), so
            # padding positions contribute to the loss — confirm intended.
            target_tok = self.components["tokenizer"](titles, add_special_tokens=False, max_length=max_len + 1, padding='max_length')
            loss = torch.nn.CrossEntropyLoss()(out_logits.permute((0, 2, 1)), torch.LongTensor(target_tok["input_ids"]).cuda())
            # Fix: return the full generated logits (previously only the last
            # decoding step), consistent with the inference-only branch below
            # and with the tensor the loss is computed on.
            return {"loss": loss, "logits": out_logits, "eval_loss": loss}

        return {"logits": out_logits}
|
|
|