File size: 6,605 Bytes
03ea185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
785dffc
 
920c442
 
785dffc
920c442
 
9fba37d
 
 
 
 
 
03ea185
 
 
 
 
920c442
03ea185
785dffc
e0030ff
785dffc
03ea185
785dffc
e0030ff
03ea185
 
 
 
 
 
 
 
 
785dffc
03ea185
 
 
 
 
3a58ec5
03ea185
 
686e2c9
03ea185
 
 
 
 
 
90540a4
 
03ea185
 
 
8561911
03ea185
 
 
 
 
 
 
b95cf18
03ea185
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a70327b
03ea185
e670039
03ea185
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from transformers.modeling_utils import PreTrainedModel
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import pdb
from transformers import OffloadedCache,DynamicCache
from .configuration_mic21 import MIC21SummarizerConfig
import numpy as np
from transformers import AutoImageProcessor, ResNetForImageClassification

class MIC21SummarizerModel(PreTrainedModel):
    """Image-to-title summarizer.

    Projects frozen ResNet feature maps into a frozen causal LM's embedding
    space and greedily decodes a title. Only the projection head
    (``projection_layer``/``projection_norm``/``projection_dropout``) is
    trainable; both pretrained backbones are loaded lazily via
    :meth:`init_components` and kept frozen.
    """

    config_class = MIC21SummarizerConfig
    is_parallelizable = True
    model_parallel = True
    # The heavy sub-models are placed on CUDA explicitly in init_components(),
    # so the HF Trainer must not move this wrapper itself.
    place_model_on_device = False
    # NOTE(review): mutable class attribute shared across instances — looks
    # like a Trainer-compat shim; confirm it is never mutated per-instance.
    model_wrapped = {}

    def init_components(self):
        """Load the frozen image backbone, LLM, tokenizer and image processor.

        Kept out of ``__init__`` so that saving/loading this wrapper's own
        checkpoint does not serialize the large pretrained sub-models (they
        live in a plain dict, not as registered sub-modules).
        """
        self.components["image_model"] = ResNetForImageClassification.from_pretrained(self.hf_config.hf_image_model).cuda()
        self.components["image_processor"] = AutoImageProcessor.from_pretrained(self.hf_config.hf_image_model)

        self.components["llm"] = AutoModelForCausalLM.from_pretrained(self.hf_config.hf_text_model, torch_dtype=torch.float16).cuda()
        self.components["tokenizer"] = AutoTokenizer.from_pretrained(self.hf_config.hf_text_model)

        # Freeze both backbones: only the projection head receives gradients.
        for param in self.components["image_model"].parameters():
            param.requires_grad = False
        for param in self.components["llm"].parameters():
            param.requires_grad = False

    def __init__(self, config):
        """Build the trainable projection head; backbones are loaded later.

        Args:
            config: ``MIC21SummarizerConfig`` providing ``hf_image_model``,
                ``hf_text_model``, ``im_model_cuda_id`` and ``output_length``.
        """
        super().__init__(config)
        # Sub-models are filled in by init_components(); a plain dict keeps
        # PreTrainedModel from registering them as sub-modules.
        self.components = {"image_model": None, "llm": None, "tokenizer": None, "image_processor": None}
        self.hf_config = config

        # Maps each channel's flattened 7x7 spatial map (49 values) into the
        # LLM embedding space.
        # NOTE(review): 2048 is hard-coded — it must equal the LLM's
        # hidden_size; confirm for the configured hf_text_model.
        self.projection_layer = torch.nn.Linear(49, 2048, dtype=torch.float)
        self.projection_norm = torch.nn.LayerNorm(49, eps=1e-5, bias=True)
        self.projection_dropout = torch.nn.Dropout(0.1)

        self.im_model_cuda_id = config.im_model_cuda_id
        # Fixed decode length; None means "default 64, stop early at EOS".
        self.output_length = config.output_length

    def forward(self, images, titles):
        """Encode ``images``, decode a title, and optionally compute a loss.

        Args:
            images: batch of images accepted by the HF image processor.
            titles: target title strings for training, or None for inference.

        Returns:
            ``{"loss", "logits", "eval_loss"}`` when ``titles`` is given,
            otherwise ``{"logits"}``; ``logits`` is the full decoded sequence,
            shape (batch, steps, vocab).
        """
        prepared_images = self.components["image_processor"](images, return_tensors="pt")
        prepared_images["pixel_values"] = prepared_images["pixel_values"].cuda()

        # Last ResNet hidden state: (batch, channels, H, W) -> flatten spatial.
        img_features = self.components["image_model"](**prepared_images, output_hidden_states=True)
        img_features = img_features["hidden_states"][-1]
        (batch_size, nfilter, nx, ny) = img_features.shape
        img_features = img_features.view(batch_size, nfilter, nx * ny)

        messages = [
            {"role":"system","content":"Generate title and description for the provided image. The image features are: "},
            {"role":"user","content":"Generate a title:"}]

        tokenized_messages = self.components["tokenizer"].apply_chat_template(messages, tokenize=True, add_generation_prompt=True, return_tensors="pt").cuda()
        vectorized_messages = self.components["llm"].model.embed_tokens(tokenized_messages[0]).unsqueeze(0)
        vectorized_messages = vectorized_messages.repeat(batch_size, 1, 1)
        # Visual tokens are spliced in just before the first EOS in the prompt.
        first_eos_index = (tokenized_messages[0] == self.components["tokenizer"].eos_token_id).nonzero()[0].item()

        # Project the first 256 channels into LLM embeddings.
        # NOTE(review): assumes nfilter >= 256 — confirm for the backbone used.
        visual_embeddings = self.projection_layer(self.projection_dropout(self.projection_norm(img_features[:, 0:256, :])))

        combined_embeds = torch.cat([
            vectorized_messages[:, :first_eos_index - 1, :],
            visual_embeddings.half(),
            vectorized_messages[:, first_eos_index:, :]], dim=1)

        # KV cache offloaded to CPU so long prompts fit on a single GPU.
        self.cache = OffloadedCache()

        # Prime the cache with the full prompt, then greedy-decode stepwise.
        outputs = self.components["llm"](inputs_embeds=combined_embeds, past_key_values=self.cache, use_cache=True)
        logits = outputs.logits[:, -1]
        out_logits = logits.unsqueeze(1)
        new_tok = torch.argmax(logits, dim=-1)

        max_len = 64 if self.output_length is None else self.output_length

        for _ in range(max_len):
            # new_tok has shape (batch,); feed it back as one token per row.
            outputs = self.components["llm"](input_ids=new_tok.unsqueeze(-1), past_key_values=self.cache, use_cache=True)
            logits = outputs.logits[:, -1]
            out_logits = torch.cat([out_logits, logits.unsqueeze(1)], dim=1)
            new_tok = torch.argmax(logits, dim=-1)
            # BUGFIX: the original tested `max_len is None`, which is never true
            # (max_len is always assigned above), so decoding never stopped at
            # EOS. Stop early only at inference time with no fixed length and a
            # single sample (.item() requires batch size 1); the training path
            # keeps the full max_len + 1 steps that the loss target expects.
            if titles is None and self.output_length is None and batch_size == 1 \
                    and new_tok.item() == self.components["tokenizer"].eos_token_id:
                break

        if titles is not None:
            # Targets padded to max_len + 1 to match the decoded step count
            # (one step from the prompt pass plus max_len loop steps).
            target_tok = self.components["tokenizer"](titles, add_special_tokens=False, max_length=max_len + 1, padding='max_length')
            loss = torch.nn.CrossEntropyLoss()(out_logits.permute((0, 2, 1)), torch.LongTensor(target_tok["input_ids"]).cuda())
            # BUGFIX: return the full sequence logits; the original returned
            # only the last step here, inconsistent with the inference branch.
            return {"loss": loss, "logits": out_logits, "eval_loss": loss}

        return {"logits": out_logits}