| from header import * |
| import os |
| import torch.nn.functional as F |
| from .ImageBind import * |
| from .ImageBind import data |
| from .modeling_llama import LlamaForCausalLM |
| from transformers import StoppingCriteria, StoppingCriteriaList |
|
|
| import torch |
| from torch.nn.utils import rnn |
|
|
class StoppingCriteriaSub(StoppingCriteria):
    '''Stop generation once a stop token id has appeared often enough.

    Only the first sequence of the batch (``input_ids[0]``) is inspected.
    '''

    def __init__(self, stops=None, encounters=1):
        '''
        stops: iterable of token ids that signal the end of generation;
            ``None`` means "no stop ids".
        encounters: how many occurrences of a stop id are required before
            generation is stopped.
        '''
        super().__init__()
        # A fresh list per instance: a mutable default would be shared by
        # every instance created without an explicit ``stops``.
        self.stops = [] if stops is None else stops
        self.ENCOUNTERS = encounters

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor, **kwargs) -> bool:
        '''Return True when any stop id occurs at least ``ENCOUNTERS`` times.

        ``**kwargs`` is accepted (and ignored) so that extra keyword
        arguments forwarded by the caller do not raise a TypeError.
        '''
        first_sequence = input_ids[0]
        # Each stop id is checked on its own; a count for one stop id must
        # not be overwritten by the count of the next one.
        return any(
            (stop == first_sequence).sum().item() >= self.ENCOUNTERS
            for stop in self.stops
        )
|
|
def build_one_instance(tokenizer, conversation):
    '''Tokenize one multi-turn conversation into input ids and loss targets.

    tokenizer: callable such that
        ``tokenizer(text, add_special_tokens=False).input_ids`` is a list of
        token ids.
    conversation: sequence of turns, each a mapping with the keys ``'from'``
        (``'human'`` or ``'gpt'``) and ``'value'`` (the turn text).

    Returns ``(text_list, input_ids, target_ids)``:
        text_list: the formatted text of every turn, in order.
        input_ids: concatenated token ids of all turns.
        target_ids: same length as ``input_ids``; human turns are filled
            with -100 so only ``'gpt'`` turns carry their own token ids.

    Raises:
        AssertionError: if the first turn is not a human turn.
        ValueError: if a later turn has a role other than 'human' or 'gpt'.
    '''
    text_list, input_ids, target_ids = [], [], []
    for i, turn in enumerate(conversation):
        role = turn['from']
        if i == 0:
            assert role == 'human', f'first turn must be from human, got {role!r}'
            # The first turn begins with the closing tag of the image
            # placeholder opened by the prompt prefix ('### Human: <Img>').
            text = '</Img> ' + turn['value'] + '\n### Assistant:'
            supervised = False
        elif role == 'human':
            text = 'Human: ' + turn['value'] + '\n### Assistant:'
            supervised = False
        elif role == 'gpt':
            text = turn['value'] + '\n###'
            supervised = True
        else:
            raise ValueError(
                f"Wrong Role!!! Got {role!r} at turn {i}; expected 'human' or 'gpt'."
            )

        one_input_id = tokenizer(text, add_special_tokens=False).input_ids
        input_ids += one_input_id
        # -100 marks positions that must not contribute to the loss.
        target_ids += one_input_id if supervised else [-100] * len(one_input_id)
        text_list.append(text)

    assert len(input_ids) == len(target_ids)
    return text_list, input_ids, target_ids
|
|
def process_batch_instance(tokenizer, batch_of_conversations, max_tgt_len):
    '''Tokenize, pad and truncate a batch of conversations.

    tokenizer: tokenizer passed through to ``build_one_instance``; its
        ``pad_token_id`` is used to pad the input ids.
    batch_of_conversations: iterable of conversations (see
        ``build_one_instance``).
    max_tgt_len: maximum number of token positions kept per sequence.

    Returns ``(input_ids, target_ids, attention_mask)``, three LongTensors
    of identical shape ``bsz x min(longest_sequence, max_tgt_len)``.
    Padded positions hold ``pad_token_id`` / -100 / 0 respectively.
    '''
    batch_input_ids, batch_target_ids, batch_masks = [], [], []
    for conversation in batch_of_conversations:
        _, one_input_ids, one_target_ids = build_one_instance(tokenizer, conversation)
        batch_input_ids.append(torch.LongTensor(one_input_ids))
        batch_target_ids.append(torch.LongTensor(one_target_ids))
        # The mask is built from the true sequence length. Deriving it from
        # ``input_ids != pad_token_id`` would also mask genuine tokens whose
        # id equals the pad id (e.g. when pad is aliased to EOS).
        batch_masks.append(torch.ones(len(one_input_ids), dtype=torch.long))

    input_ids = rnn.pad_sequence(batch_input_ids, batch_first=True, padding_value=tokenizer.pad_token_id)
    target_ids = rnn.pad_sequence(batch_target_ids, batch_first=True, padding_value=-100)
    attention_mask = rnn.pad_sequence(batch_masks, batch_first=True, padding_value=0)
    assert input_ids.size() == target_ids.size()

    # Truncate after padding so all three tensors keep the same shape.
    input_ids = input_ids[:, :max_tgt_len]
    target_ids = target_ids[:, :max_tgt_len]
    attention_mask = attention_mask[:, :max_tgt_len]
    assert attention_mask.size() == input_ids.size()
    return input_ids, target_ids, attention_mask
|
|
# Text that precedes the modality embedding in every prompt; the matching
# '</Img>' closing tag is emitted at the start of the text that follows it.
PROMPT_START = '### Human: <Img>'


class OpenLLAMAPEFTModel(nn.Module):
    '''LoRA for LLaMa model

    A frozen ImageBind encoder produces one embedding per input, a trainable
    linear layer projects it to the LLaMA hidden size, and the LoRA-wrapped
    LLaMA model consumes ``<bos> PROMPT_START <modality embedding> text``.
    '''

    def __init__(self, **args):
        '''
        Required keys in ``args``: 'imagebind_ckpt_path', 'vicuna_ckpt_path',
        'max_tgt_len', 'lora_r', 'lora_alpha', 'lora_dropout'.

        The optional environment variable ``API_TOKEN`` is forwarded as the
        auth token when loading the language model and tokenizer.
        '''
        super(OpenLLAMAPEFTModel, self).__init__()
        self.args = args
        imagebind_ckpt_path = args['imagebind_ckpt_path']
        vicuna_ckpt_path = args['vicuna_ckpt_path']
        max_tgt_len = args['max_tgt_len']
        # A missing token must not abort loading (e.g. local checkpoints);
        # ``None`` means "no token".
        auth_token = os.environ.get('API_TOKEN')

        print (f'Initializing visual encoder from {imagebind_ckpt_path} ...')
        self.visual_encoder, self.visual_hidden_size = \
            imagebind_model.imagebind_huge(pretrained=True, store_path=imagebind_ckpt_path)
        # The encoder is used as a frozen feature extractor.
        for param in self.visual_encoder.parameters():
            param.requires_grad = False
        self.visual_encoder.eval()
        print ('Visual encoder initialized.')

        print (f'Initializing language decoder from {vicuna_ckpt_path} ...')
        peft_config = LoraConfig(
            task_type=TaskType.CAUSAL_LM,
            inference_mode=False,
            r=self.args['lora_r'],
            lora_alpha=self.args['lora_alpha'],
            lora_dropout=self.args['lora_dropout'],
            target_modules=['q_proj', 'k_proj', 'v_proj', 'o_proj']
        )

        self.llama_model = LlamaForCausalLM.from_pretrained(vicuna_ckpt_path, use_auth_token=auth_token)
        self.llama_model = get_peft_model(self.llama_model, peft_config)
        self.llama_model.print_trainable_parameters()

        self.llama_tokenizer = LlamaTokenizer.from_pretrained(vicuna_ckpt_path, use_fast=False, use_auth_token=auth_token)
        # Padding reuses the EOS token.
        self.llama_tokenizer.pad_token = self.llama_tokenizer.eos_token
        self.llama_tokenizer.padding_side = "right"
        print ('Language decoder initialized.')

        # Maps encoder embeddings to the language model's hidden size.
        self.llama_proj = nn.Linear(
            self.visual_hidden_size, self.llama_model.config.hidden_size
        )

        self.max_tgt_len = max_tgt_len
        self.device = torch.cuda.current_device() if torch.cuda.is_available() else torch.device('cpu')

    def _encode_modality(self, modality, loader, paths):
        '''Load ``paths`` with ``loader``, encode them and project to LLaMA space.

        Returns ``(inputs_llama, atts_llama)``: the projected embeddings with
        an extra dimension inserted at position 1, and an all-ones long mask
        shaped like ``inputs_llama`` without its last dimension.
        '''
        inputs = {modality: loader(paths, self.device).to(self.llama_model.dtype)}
        # The encoder is frozen, so no graph is needed for it; the projection
        # below stays outside ``no_grad`` so it can be trained.
        with torch.no_grad():
            embeddings = self.visual_encoder(inputs)
            modality_embeds = embeddings[modality]
        inputs_llama = self.llama_proj(modality_embeds).unsqueeze(1)
        atts_llama = torch.ones(inputs_llama.size()[:-1], dtype=torch.long).to(self.device)
        return inputs_llama, atts_llama

    def encode_video(self, video_paths):
        '''Encode videos; see ``_encode_modality`` for the return value.'''
        return self._encode_modality(
            ModalityType.VISION, data.load_and_transform_video_data, video_paths)

    def encode_audio(self, audio_paths):
        '''Encode audio clips; see ``_encode_modality`` for the return value.'''
        return self._encode_modality(
            ModalityType.AUDIO, data.load_and_transform_audio_data, audio_paths)

    def encode_thermal(self, thermal_paths):
        '''Encode thermal images; see ``_encode_modality`` for the return value.'''
        return self._encode_modality(
            ModalityType.THERMAL, data.load_and_transform_thermal_data, thermal_paths)

    def encode_image(self, image_paths):
        '''Encode images; see ``_encode_modality`` for the return value.'''
        return self._encode_modality(
            ModalityType.VISION, data.load_and_transform_vision_data, image_paths)

    def _embed_tokens(self, token_ids):
        '''Look up token embeddings in the wrapped language model.'''
        return self.llama_model.model.model.embed_tokens(token_ids)

    def _prefix_embeds(self, batch_size):
        '''Return the embeddings of ``<bos> PROMPT_START`` for each batch row.'''
        p_before_tokens = self.llama_tokenizer(PROMPT_START,
            return_tensors="pt", add_special_tokens=False).to(self.device)
        p_before_embeds = self._embed_tokens(p_before_tokens.input_ids).expand(batch_size, -1, -1)
        bos = torch.full([batch_size, 1], self.llama_tokenizer.bos_token_id,
                         dtype=p_before_tokens.input_ids.dtype,
                         device=p_before_tokens.input_ids.device)
        bos_embeds = self._embed_tokens(bos)
        return torch.cat([bos_embeds, p_before_embeds], dim=1)

    def prompt_wrap(self, img_embeds, input_ids, target_ids, attention_mask):
        '''
        input_ids, target_ids, attention_mask: bsz x s2

        Prepends ``<bos> PROMPT_START <img_embeds>`` to the embedded
        ``input_ids`` and extends ``target_ids`` / ``attention_mask`` to match:
        prefix targets are -100 (ignored by the loss), prefix mask is 1.

        Returns ``(inputs_embeds, targets, attention_mask)``.
        '''
        input_ids = input_ids.to(self.device)
        target_ids = target_ids.to(self.device)
        attention_mask = attention_mask.to(self.device)

        batch_size = img_embeds.shape[0]
        prefix_embeds = self._prefix_embeds(batch_size)
        p_after_embeds = self._embed_tokens(input_ids).expand(batch_size, -1, -1)
        inputs_embeds = torch.cat([prefix_embeds, img_embeds, p_after_embeds], dim=1)

        # Number of positions in front of the text: bos + prompt + however
        # many positions the modality embedding occupies.
        prefix_len = prefix_embeds.size(1) + img_embeds.size(1)

        empty_targets = torch.full([batch_size, prefix_len], -100,
                                   dtype=torch.long, device=self.device)
        targets = torch.cat([empty_targets, target_ids], dim=1)
        assert inputs_embeds.size()[1] == targets.size()[1]

        atts_prefix = torch.ones([batch_size, prefix_len], dtype=torch.long, device=self.device)
        attention_mask = torch.cat([atts_prefix, attention_mask], dim=1)
        assert attention_mask.size() == targets.size()
        return inputs_embeds, targets, attention_mask

    def forward(self, inputs):
        '''Compute the training loss and next-token accuracy for a batch.

        inputs: dict with 'image_paths' (passed to ``encode_image``) and
            'output_texts' (batch of conversations, see
            ``process_batch_instance``).

        Returns ``(loss, gen_acc)`` where ``gen_acc`` is the fraction of
        supervised positions whose arg-max prediction equals the label
        (0.0 when the batch has no supervised position).
        '''
        image_paths = inputs['image_paths']
        img_embeds, _ = self.encode_image(image_paths)

        output_texts = inputs['output_texts']
        input_ids, target_ids, attention_mask = process_batch_instance(self.llama_tokenizer, output_texts, self.max_tgt_len)
        inputs_embeds, targets, attention_mask = self.prompt_wrap(img_embeds, input_ids, target_ids, attention_mask)

        outputs = self.llama_model(
            inputs_embeds=inputs_embeds,
            attention_mask=attention_mask,
            return_dict=True,
            labels=targets,
        )
        loss = outputs.loss

        # Logits at position t are compared with the target at position t+1;
        # the first position and the final logit are dropped.
        chosen_tokens = torch.max(outputs.logits, dim=-1)[1][:, 1:-1]
        labels = targets[:, 2:]
        valid_mask = labels != -100
        num_valid = valid_mask.sum().item()
        num_correct = ((chosen_tokens == labels) & valid_mask).sum().item()
        # Truncation can remove every supervised token; avoid dividing by zero.
        gen_acc = num_correct / num_valid if num_valid > 0 else 0.0
        return loss, gen_acc

    def extract_multimodal_feature(self, inputs):
        '''Encode every provided modality and sum all embeddings into one.

        inputs: dict that may contain 'image_paths', 'audio_paths',
            'video_paths' and 'thermal_paths'; missing or empty entries are
            skipped.

        Returns the sum over all encoded items with a new leading dimension
        of size 1. ``torch.cat`` raises if no modality was provided.
        '''
        encoders = (
            ('image_paths', self.encode_image),
            ('audio_paths', self.encode_audio),
            ('video_paths', self.encode_video),
            ('thermal_paths', self.encode_thermal),
        )
        features = []
        for key, encode in encoders:
            paths = inputs.get(key)
            if paths:
                embeds, _ = encode(paths)
                features.append(embeds)

        feature_embeds = torch.cat(features).sum(dim=0).unsqueeze(0)
        return feature_embeds

    def prepare_generation_embedding(self, inputs):
        '''Build the input embeddings for generation.

        inputs: dict with 'prompt' and, optionally, 'modality_embeds' — a
            list used as a cache. When it holds exactly one entry that entry
            is reused; otherwise the features are extracted from the path
            entries of ``inputs`` and appended to the list. A missing or
            ``None`` cache is replaced by a new list stored back in ``inputs``.

        Returns the embeddings of
        ``<bos> PROMPT_START <features> '</Img> ' prompt '\\n### Assistant:'``.
        '''
        prompt = inputs['prompt']
        cached = inputs.get('modality_embeds')
        if cached is None:
            cached = []
            inputs['modality_embeds'] = cached
        if len(cached) == 1:
            feature_embeds = cached[0]
        else:
            feature_embeds = self.extract_multimodal_feature(inputs)
            cached.append(feature_embeds)

        batch_size = feature_embeds.shape[0]
        prefix_embeds = self._prefix_embeds(batch_size)
        text = '</Img> ' + prompt + '\n### Assistant:'
        p_after_tokens = self.llama_tokenizer(text, add_special_tokens=False, return_tensors='pt').to(self.device)
        p_after_embeds = self._embed_tokens(p_after_tokens.input_ids).expand(batch_size, -1, -1)
        inputs_embeds = torch.cat([prefix_embeds, feature_embeds, p_after_embeds], dim=1)
        return inputs_embeds

    def generate(self, inputs):
        '''
        inputs = {
            'image_paths': optional,
            'audio_paths': optional
            'video_paths': optional
            'thermal_paths': optional
            'mode': generation mode,
            'prompt': human input prompt,
            'max_tgt_len': generation length,
            'top_p': top_p,
            'temperature': temperature
            'modality_embeds': list caching the extracted modality embedding
                (None or empty to extract it from the paths)
            'modality_cache': save the image cache
        }

        Returns the decoded text of the first generated sequence.
        '''
        # Inference only: no autograd graph is needed for the prompt embeddings.
        with torch.no_grad():
            input_embeds = self.prepare_generation_embedding(inputs)
        # NOTE(review): 2277 is presumably the token id of the '###' turn
        # separator in the LLaMA vocabulary — confirm against the tokenizer.
        stopping_criteria = StoppingCriteriaList([StoppingCriteriaSub(stops=[2277], encounters=1)])
        outputs = self.llama_model.generate(
            inputs_embeds=input_embeds,
            max_new_tokens=inputs['max_tgt_len'],
            top_p=inputs['top_p'],
            temperature=inputs['temperature'],
            do_sample=True,
            use_cache=True,
            stopping_criteria=stopping_criteria,
        )
        # The last two generated tokens are dropped before decoding
        # (presumably the trailing stop marker — confirm).
        output_text = self.llama_tokenizer.decode(outputs[0][:-2], skip_special_tokens=True)
        return output_text
|
|
|
|