# mic21_model / modeling_mic21.py
# Uploaded by jkralev via huggingface_hub (commit e670039, verified).
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import pdb
from transformers import OffloadedCache,DynamicCache
from .configuration_mic21 import MIC21SummarizerConfig
import numpy as np
from transformers import AutoImageProcessor, ResNetForImageClassification
class MIC21SummarizerModel(PreTrainedModel):
    """Image-to-title summarizer.

    A frozen ResNet produces spatial image features; a small trainable head
    (LayerNorm -> Dropout -> Linear) projects them into the embedding space of
    a frozen causal LM, which then greedily decodes a title token by token.

    NOTE(review): every sub-module is moved to CUDA unconditionally, so a GPU
    is required as written.
    """

    config_class = MIC21SummarizerConfig
    is_parallelizable = True
    model_parallel = True
    place_model_on_device = False
    # NOTE(review): mutable class-level attribute — shared across instances.
    # Kept as-is for Trainer interface compatibility.
    model_wrapped = {}

    def init_components(self):
        """Load the frozen backbones and their processors into ``self.components``.

        The components live in a plain dict rather than being registered as
        sub-modules, presumably so they are excluded from this model's
        state_dict and optimizer — only the projection head trains.
        """
        self.components["image_model"] = ResNetForImageClassification.from_pretrained(
            self.hf_config.hf_image_model
        ).cuda()
        self.components["image_processor"] = AutoImageProcessor.from_pretrained(
            self.hf_config.hf_image_model
        )
        self.components["llm"] = AutoModelForCausalLM.from_pretrained(
            self.hf_config.hf_text_model, torch_dtype=torch.float16
        ).cuda()
        self.components["tokenizer"] = AutoTokenizer.from_pretrained(
            self.hf_config.hf_text_model
        )
        # Freeze both pretrained backbones; the projection head stays trainable.
        for param in self.components["image_model"].parameters():
            param.requires_grad = False
        for param in self.components["llm"].parameters():
            param.requires_grad = False

    def __init__(self, config):
        """Build the trainable projection head.

        Heavy pretrained components are loaded lazily via ``init_components``;
        until then the entries of ``self.components`` are ``None``.
        """
        super().__init__(config)
        self.components = {
            "image_model": None,
            "llm": None,
            "tokenizer": None,
            "image_processor": None,
        }
        self.hf_config = config
        # Maps each ResNet filter's 49 (= 7x7, presumably) flattened spatial
        # values to the LLM embedding width.  2048 is hard-coded — TODO confirm
        # it matches self.components["llm"].config.hidden_size.
        self.projection_layer = torch.nn.Linear(49, 2048, dtype=torch.float)
        self.projection_norm = torch.nn.LayerNorm(49, eps=1e-5, bias=True)
        self.projection_dropout = torch.nn.Dropout(0.1)
        self.im_model_cuda_id = config.im_model_cuda_id
        # Fixed number of decoding steps; None means "up to 64, stop at EOS".
        self.output_length = config.output_length

    def forward(self, images, titles):
        """Greedily decode a title per image; compute a loss when titles given.

        Args:
            images: batch of images accepted by the image processor.
            titles: optional reference titles (sequence of str) for the loss.

        Returns:
            ``{"loss", "logits", "eval_loss"}`` when ``titles`` is not None,
            else ``{"logits": out_logits}`` with the per-step logits
            (batch, steps + 1, vocab).
        """
        prepared_images = self.components["image_processor"](images, return_tensors="pt")
        prepared_images["pixel_values"] = prepared_images["pixel_values"].cuda()
        img_features = self.components["image_model"](
            **prepared_images, output_hidden_states=True
        )
        # Last hidden state is (batch, filters, H, W); flatten the spatial dims.
        img_features = img_features["hidden_states"][-1]
        (batch_size, nfilter, nx, ny) = img_features.shape
        img_features = img_features.view(batch_size, nfilter, nx * ny)

        messages = [
            {"role": "system", "content": "Generate title and description for the provided image. The image features are: "},
            {"role": "user", "content": "Generate a title:"},
        ]
        tokenized_messages = self.components["tokenizer"].apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        ).cuda()
        # Embed the prompt once, then replicate it across the batch.
        vectorized_messages = (
            self.components["llm"].model.embed_tokens(tokenized_messages[0]).unsqueeze(0)
        )
        vectorized_messages = vectorized_messages.repeat(batch_size, 1, 1)
        # Visual embeddings are spliced in at the first EOS of the prompt.
        first_eos_index = (
            tokenized_messages[0] == self.components["tokenizer"].eos_token_id
        ).nonzero()[0].item()
        # Only the first 256 filters are projected — TODO confirm intentional.
        visual_embeddings = self.projection_layer(
            self.projection_dropout(self.projection_norm(img_features[:, 0:256, :]))
        )
        # NOTE(review): the `first_eos_index - 1` slice also drops the token
        # immediately BEFORE the EOS — verify that is intended.
        combined_embeds = torch.cat(
            [
                vectorized_messages[:, : first_eos_index - 1, :],
                visual_embeddings.half(),
                vectorized_messages[:, first_eos_index:, :],
            ],
            dim=1,
        )

        # KV-cache is stored on self (stateful across calls by design here).
        self.cache = OffloadedCache()
        outputs = self.components["llm"](
            inputs_embeds=combined_embeds, past_key_values=self.cache, use_cache=True
        )
        logits = outputs.logits[:, -1]
        out_logits = logits.unsqueeze(1)
        new_tok = torch.argmax(logits, dim=-1)

        max_len = 64 if self.output_length is None else self.output_length
        for _ in range(max_len):
            outputs = self.components["llm"](
                input_ids=new_tok.unsqueeze(1),  # (batch, 1) next-token ids
                past_key_values=self.cache,
                use_cache=True,
            )
            logits = outputs.logits[:, -1]
            out_logits = torch.cat([out_logits, logits.unsqueeze(1)], dim=1)
            new_tok = torch.argmax(logits, dim=-1)
            # BUG FIX: the original tested `max_len is None`, which was never
            # true (max_len is always an int above), so the EOS early-stop was
            # dead code.  Intent: stop at EOS when no fixed length was set.
            # Guarded to inference (titles is None) so the loss path below
            # always sees max_len + 1 logit steps, and to batch_size == 1
            # because .item() only works on a single-element tensor.
            if (
                titles is None
                and self.output_length is None
                and batch_size == 1
                and new_tok.item() == self.components["tokenizer"].eos_token_id
            ):
                break

        if titles is not None:
            # Targets padded to max_len + 1 to match the logit step count
            # (1 from the prompt pass + max_len from the loop).
            target_tok = self.components["tokenizer"](
                titles,
                add_special_tokens=False,
                max_length=max_len + 1,
                padding="max_length",
            )
            loss = torch.nn.CrossEntropyLoss()(
                out_logits.permute((0, 2, 1)),
                torch.LongTensor(target_tok["input_ids"]).cuda(),
            )
            return {"loss": loss, "logits": logits, "eval_loss": loss}
        return {"logits": out_logits}