import os

import torch
import torch.nn as nn
from torch.nn.functional import cross_entropy
from transformers import CLIPVisionModel, AutoModelForCausalLM, BitsAndBytesConfig
import peft
from peft import LoraConfig
from tqdm import tqdm
class CustomClipPhi2(nn.Module):
    def __init__(self, tokenizer, phi2_model_name, clip_model_name, clip_embed=768, phi_embed=2560):
        super().__init__()
        self.tokenizer = tokenizer
        # These two models are not finetuned
        # pretrained Microsoft phi-2 model
        self.phi2_model = AutoModelForCausalLM.from_pretrained(phi2_model_name, torch_dtype=torch.float32, trust_remote_code=True)
        # pretrained OpenAI CLIP vision model
        self.clip_model = CLIPVisionModel.from_pretrained(clip_model_name)
        self.EOS_TOKEN_ID = self.tokenizer.eos_token_id  # 50256
        self.IMAGE_TOKEN_ID = 23903  # id of the "Comments" token, used as the separator between image embeddings and caption
        self.clip_embed = clip_embed
        self.phi_embed = phi_embed
        # Trainable projection layer from CLIP space to the phi-2 embedding space
        self.projection_layer = torch.nn.Linear(clip_embed, phi_embed)
        # Freeze phi-2 and CLIP weights; only the projection layer is trained in phase 1
        for models in [self.phi2_model, self.clip_model]:
            for param in models.parameters():
                param.requires_grad_(False)
        # Load checkpoint weights for the projection layer, if present
        if os.path.exists('./ckpts/model_phase1.pth'):
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location='cpu'))
            print("Loaded checkpoint weights for projection layer")
        else:
            print("No checkpoint weights for projection layer")
            print("Initializing projection layer with random weights")
            self.projection_layer.weight.data.normal_(mean=0.0, std=0.02)
            self.projection_layer.bias.data.zero_()
    def generate(self, images, tokenizer, config):
        clip_outputs = self.clip_model(**images)
        # remove CLS token, keep only the patch embeddings
        images = clip_outputs.last_hidden_state[:, 1:, :]
        # cast to the LM embedding dtype so torch.cat below sees matching dtypes
        image_embeddings = self.projection_layer(images).to(self.phi2_model.model.embed_tokens.weight.dtype)
        batch_size = images.size()[0]
        predicted_caption = torch.full((batch_size, config.get("max_tokens")), self.EOS_TOKEN_ID, dtype=torch.long, device=config.get('device'))
        img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
        img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)
        # greedy decoding: append the embedding of each predicted token and re-run the LM
        for pos in range(config.get("max_tokens") - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, pos] = predicted_word_token.view(-1)
            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        return predicted_caption
    def forward(self, images, target_captions):
        batch_size = target_captions.size()[0]
        target_length = target_captions.size()[1]
        # CLIP vision encoder output for the image
        # (see e.g. https://huggingface.co/openai/clip-vit-base-patch32 for loading details)
        clip_outputs = self.clip_model(**images)
        images = clip_outputs.last_hidden_state[:, 1:, :]  # remove CLS token
        # projection layer; cast to the LM embedding dtype so torch.cat below sees matching dtypes
        image_embeddings = self.projection_layer(images).to(self.phi2_model.model.embed_tokens.weight.dtype)
        # append the image-separator token embedding from phi-2
        img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
        img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
        combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)  # (batch, num_patches + 1, phi_embed)
        del clip_outputs
        del image_embeddings
        # autoregressive loss: feed the model its own predictions back and compare each step against the target caption
        loss = 0
        for pos in range(target_length - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            pos_loss = cross_entropy(predicted_word_token_logits.view(-1, predicted_word_token_logits.size(-1)),
                                     target_captions[:, pos].contiguous().view(-1),
                                     ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        loss = loss / target_length
        # Delete variables to free up memory
        del combined_embeds
        del model_output_logits
        torch.cuda.empty_cache()
        return loss
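# ----------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): captioning one image with
# the phase 1 model. The checkpoint names, the CLIPImageProcessor preprocessing
# and the config keys ("device", "max_tokens") are assumptions inferred from how
# CustomClipPhi2.generate is called elsewhere in this file.
# ----------------------------------------------------------------------------
def _example_caption_single_image(image_path):
    from PIL import Image
    from transformers import AutoTokenizer, CLIPImageProcessor

    config = {"device": "cuda" if torch.cuda.is_available() else "cpu", "max_tokens": 50}
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)  # assumed checkpoint
    processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")  # assumed checkpoint
    model = CustomClipPhi2(tokenizer, "microsoft/phi-2", "openai/clip-vit-base-patch32").to(config["device"])
    model.eval()
    pixel_values = processor(images=Image.open(image_path), return_tensors="pt").pixel_values
    with torch.no_grad():
        caption_ids = model.generate({'pixel_values': pixel_values.to(config["device"])}, tokenizer, config)
    return tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0]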
def show_results_for_samples_phase1(model, val_dataloader, tokenizer, config, num_samples=2):
    model.eval()
    with torch.no_grad():
        # decode and print predictions for the first `num_samples` validation batches
        for i, (images, target_captions) in enumerate(val_dataloader):
            if i >= num_samples:
                break
            images = {'pixel_values': images.to(config.get('device'))}
            target_captions = target_captions.to(config.get('device'))
            target_captions_decoded = tokenizer.batch_decode(target_captions, skip_special_tokens=True)
            predicted_captions = model.generate(images, tokenizer, config)
            predicted_captions_decoded = tokenizer.batch_decode(predicted_captions, skip_special_tokens=True)
            for idx, pc in enumerate(predicted_captions_decoded):
                print(f"{idx} - Target caption: {target_captions_decoded[idx]} \n {'-' * 100} \n Predicted caption: {pc}")
def validate_model_phase1(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        try:
            for images, target_captions in tqdm(val_dataloader):
                images = {'pixel_values': images.to(config.get('device'))}
                target_captions = target_captions.to(config.get('device'))
                loss = model(images, target_captions)
                total_loss += loss.item()
            print(f"Validation Loss: {total_loss / len(val_dataloader)}")
        except Exception as e:
            print(f"Validation failed: {e}")
    model.train()
def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer, config):
    model.train()
    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        pbar = tqdm(train_loader)
        step = 1
        try:
            for idx, (images, target_captions) in enumerate(pbar):
                try:
                    if target_captions.shape[1] >= config.get("max_tokens"):
                        # print(f"Skipping batch {idx} due to long caption")
                        continue
                    images = {'pixel_values': images.to(config.get('device'))}
                    target_captions = target_captions.to(config.get('device'))
                    optimizer.zero_grad()
                    loss = model(images, target_captions)
                    loss.backward()
                    optimizer.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    if step % 1000 == 0:
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
                except Exception as e:
                    print(e)
                    continue
            # save model
            # if ((epoch % 2) == 0):
            # Only save last checkpoint
            validate_model_phase1(model, val_dataloader, tokenizer, config)
            show_results_for_samples_phase1(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
        except Exception as e:
            print(e)
            continue
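# ----------------------------------------------------------------------------
# Hedged usage sketch (not part of the original training script): a minimal
# phase 1 driver showing the config keys and objects the functions above expect.
# The dataloaders are assumed to yield (pixel_values, caption_token_ids) pairs;
# the checkpoint names and hyperparameter values here are illustrative guesses,
# not the author's actual setup.
# ----------------------------------------------------------------------------
def _example_phase1_driver(train_loader, val_dataloader):
    from transformers import AutoTokenizer

    os.makedirs('./ckpts', exist_ok=True)  # checkpoints are saved under ./ckpts
    config = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "max_tokens": 50,
        "epochs": 2,
    }
    tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
    model = CustomClipPhi2(tokenizer, "microsoft/phi-2", "openai/clip-vit-base-patch32").to(config["device"])
    # Only the projection layer has trainable parameters in phase 1
    optimizer = torch.optim.Adam(model.projection_layer.parameters(), lr=1e-4)
    train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer, config)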
######################################## Phase 2 #########################################
class MainQLoraModel(nn.Module):
    def __init__(self, tokenizer, config):
        super().__init__()
        self.tokenizer = tokenizer
        self.config = config
        self.clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))
        # Load phi-2 in 4-bit (QLoRA base model)
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_compute_dtype=torch.float16,
        )
        phi2_model = AutoModelForCausalLM.from_pretrained(
            config.get("phi2_model_name"),
            quantization_config=bnb_config,
            trust_remote_code=True
        )
        phi2_model.config.use_cache = False
        # LoRA config
        lora_alpha = 16
        lora_dropout = 0.1
        lora_r = 64
        peft_config = LoraConfig(
            lora_alpha=lora_alpha,
            lora_dropout=lora_dropout,
            r=lora_r,
            bias="none",
            task_type="CAUSAL_LM",
            target_modules=[
                "q_proj",
                "k_proj",
                "v_proj",
                "dense",
                "fc1",
                "fc2"
            ]
        )
        self.phi2_model = peft.get_peft_model(phi2_model, peft_config).to(config.get("device"))
        self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
        self.clip_embed = config.get("clip_embed")
        self.phi_embed = config.get("phi_embed")
        # Trainable projection layer from CLIP space to the phi-2 embedding space
        self.projection_layer = torch.nn.Linear(self.clip_embed, self.phi_embed)
        # Freeze the CLIP weights; only the projection layer and the LoRA adapters are trained
        for param in self.clip_model.parameters():
            param.requires_grad_(False)
        # load checkpoint weights
        if os.path.exists('./ckpts/model_phase2.pth'):
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
            # load the saved QLoRA adapter weights into the PEFT-wrapped model
            self.phi2_model.load_adapter('./ckpts/Qlora_adaptor', adapter_name='default')
            print("Loaded checkpoint weights for projection layer and QLoRA adapter")
        else:
            # Load projection layer weights from phase 1
            self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location=config.get("device")))
    def generate(self, tokenizer, config, images=None, ques=None, max_tokens=100):
        batch_size = 1
        predicted_caption = torch.full((batch_size, max_tokens), self.EOS_TOKEN_ID, dtype=torch.long, device=self.config.get('device'))
        # <iQ> ... </iQ> markers delimit the image + question prompt
        start_iq = self.tokenizer.encode("<iQ>")
        end_iq = self.tokenizer.encode("</iQ>")
        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
        # note the extra .model: the PEFT wrapper adds one level of nesting
        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
        questions_embed = self.phi2_model.model.model.embed_tokens(ques)
        if images is not None:
            clip_outputs = self.clip_model(**images)
            # remove CLS token
            images = clip_outputs.last_hidden_state[:, 1:, :]
            # cast to the token-embedding dtype before concatenation
            image_embeddings = self.projection_layer(images).to(questions_embed.dtype)
            combined_embeds = torch.cat([start_iq_embeds, image_embeddings, questions_embed, end_iq_embeds], dim=1)
        else:
            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds], dim=1)
        # greedy decoding
        for pos in range(max_tokens - 1):
            model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
            predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
            predicted_caption[:, pos] = predicted_word_token.view(-1)
            next_token_embeds = self.phi2_model.model.model.embed_tokens(predicted_word_token)
            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
        return predicted_caption
    def forward(self, images, ques, ans):
        batch_size = ques.size()[0]
        questions = ques.to(self.config.get("device"))
        answers = ans.to(self.config.get("device"))
        target_length = ans.size()[1]
        start_iq = self.tokenizer.encode("<iQ>")
        end_iq = self.tokenizer.encode("</iQ>")
        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
        questions_embed = self.phi2_model.model.model.embed_tokens(questions)
        answers_embed = self.phi2_model.model.model.embed_tokens(answers)
        # an all-zero image tensor marks a text-only (question/answer) sample
        are_all_zeros = torch.all(images == 0).item()
        if are_all_zeros:
            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
        else:
            images = {'pixel_values': images.to(self.config.get("device"))}
            clip_outputs = self.clip_model(**images)
            images_embeds = clip_outputs.last_hidden_state[:, 1:, :]  # remove CLS token
            # projection; cast to the token-embedding dtype before concatenation
            image_embeds = self.projection_layer(images_embeds).to(questions_embed.dtype)
            combined_embeds = torch.cat([start_iq_embeds, image_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
        model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
        # Teacher-forced loss: the answer embeddings occupy the last `target_length`
        # positions of combined_embeds, so the logits at position (answer_start - 1 + pos)
        # predict answers[:, pos].
        answer_start = combined_embeds.size(1) - target_length
        loss = 0
        for pos in range(target_length - 1):
            predicted_word_token_logits = model_output_logits[:, answer_start - 1 + pos, :]
            pos_loss = cross_entropy(predicted_word_token_logits,
                                     answers[:, pos].contiguous().view(-1),
                                     ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
            loss += pos_loss
        loss = loss / target_length
        # Delete variables to free up memory
        del combined_embeds
        del model_output_logits
        torch.cuda.empty_cache()
        return loss
def validate_model_phase2(model, val_dataloader, tokenizer, config):
    model.eval()
    total_loss = 0
    with torch.no_grad():
        for images, ques, ans in tqdm(val_dataloader):
            loss = model(images, ques, ans)
            total_loss += loss.item()
        print(f"Validation Loss: {total_loss / len(val_dataloader)}")
    model.train()
def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
    # separate optimizers for the QLoRA adapters and the projection layer
    phi2_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.phi2_model.parameters()), lr=1e-5)
    proj_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.projection_layer.parameters()), lr=1e-5)
    model.phi2_model.train()
    model.projection_layer.train()
    for epoch in range(1, config.get("epochs") + 1):
        print(f"Epoch: {epoch}")
        torch.cuda.empty_cache()
        pbar = tqdm(train_loader)
        step = 1
        try:
            for idx, (images, ques, ans) in enumerate(pbar):
                try:
                    phi2_optim.zero_grad()
                    proj_optim.zero_grad()
                    loss = model(images, ques, ans)
                    loss.backward()
                    phi2_optim.step()
                    proj_optim.step()
                    pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
                    torch.cuda.empty_cache()
                    step += 1
                    if step % 1000 == 0:
                        torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
                        model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
                except Exception as e:
                    print(f"Skipping batch {idx} due to error: {e}")
                    continue
            validate_model_phase2(model, val_dataloader, tokenizer, config)
            torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
            model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
        except Exception as e:
            print(e)
            continue
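# ----------------------------------------------------------------------------
# Hedged usage sketch (not part of the original file): a minimal phase 2 driver.
# The config key names and the dataloader format (image tensor, question ids,
# answer ids) are assumptions inferred from MainQLoraModel.forward above, and
# this assumes the phase 1 checkpoint ./ckpts/model_phase1.pth already exists.
# ----------------------------------------------------------------------------
def _example_phase2_driver(train_loader, val_dataloader):
    from transformers import AutoTokenizer

    config = {
        "device": "cuda" if torch.cuda.is_available() else "cpu",
        "epochs": 2,
        "clip_model_name": "openai/clip-vit-base-patch32",  # assumed checkpoint
        "phi2_model_name": "microsoft/phi-2",               # assumed checkpoint
        "clip_embed": 768,
        "phi_embed": 2560,
    }
    tokenizer = AutoTokenizer.from_pretrained(config["phi2_model_name"], trust_remote_code=True)
    model = MainQLoraModel(tokenizer, config)
    # The quantized phi-2 is placed on the device in __init__; move the rest explicitly
    model.clip_model.to(config["device"])
    model.projection_layer.to(config["device"])
    train_model_phase2(model, train_loader, val_dataloader, tokenizer, config)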