Vasudevakrishna committed on
Commit 178553c · verified · 1 Parent(s): 80b4630

Upload 5 files

Files changed (5)
  1. configs.py +36 -0
  2. dataset.py +204 -0
  3. get_coco.py +41 -0
  4. main.py +41 -0
  5. model.py +336 -0
configs.py ADDED
@@ -0,0 +1,36 @@
+ import torch
+
+ def get_config_phase1():
+     return {
+         "data_dir": "./data",
+         "clip_model_name": "openai/clip-vit-base-patch16",
+         "phi2_model_name": "microsoft/phi-2",
+         "train_batch_size": 2,
+         "val_batch_size": 1,
+         "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+         "epochs": 2,
+         "max_tokens": 20,
+         "clip_embed": 768,
+         "phi_embed": 2560,
+         "num_workers": 32,
+         "ckpts": "./ckpts"
+     }
+
+ def get_config_phase2():
+     return {
+         "i150k_json": "./data/llava_instruct_150k.json",
+         "QA_datasetName": "OpenAssistant/oasst1",
+         "clip_model_name": "openai/clip-vit-base-patch16",
+         "phi2_model_name": "microsoft/phi-2",
+         "train_batch_size": 1,
+         "val_batch_size": 1,
+         "device": torch.device("cuda" if torch.cuda.is_available() else "cpu"),
+         "epochs": 2,
+         "max_tokens": 20,
+         "clip_embed": 768,
+         "phi_embed": 2560,
+         "num_workers": 1,
+         "ckpts": "./ckpts",
+         "vocab_size": 51200
+     }
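Note: a minimal usage sketch (not part of the commit) showing how these config dictionaries are consumed downstream; the keys are exactly the ones defined above.

    from configs import get_config_phase1, get_config_phase2

    config = get_config_phase1()
    print(config["device"])                           # cuda if available, else cpu
    print(config["clip_embed"], config["phi_embed"])  # 768 -> 2560 projection dimensions

    config2 = get_config_phase2()
    print(config2["vocab_size"])                      # 51200, used to reshape phi-2 logits in phase 2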
dataset.py ADDED
@@ -0,0 +1,204 @@
+ import os
+ import json
+ import pickle
+ import requests
+ import torch
+ import pandas as pd
+ import numpy as np
+ from PIL import Image
+ from torch.utils.data import Dataset, DataLoader
+ from transformers import AutoProcessor
+ from datasets import load_dataset  # note: importing datasets.Dataset here would shadow torch's Dataset base class
+
+
+ class ClipDataset(Dataset):
+     '''Dataset of (image, caption) pairs for phase-1 training'''
+     def __init__(self, coco_data, model_name, tokenizer):
+         self.tokenizer = tokenizer
+         self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+         self.caption_dataset = coco_data
+
+     def __len__(self):
+         # Return the length of the dataset
+         return len(self.caption_dataset)
+
+     def __getitem__(self, idx):
+         # Get the image url and caption
+         img_url = self.caption_dataset[idx]["image_url"]
+         caption = self.caption_dataset[idx]["caption"]
+
+         # Download the image and scale it so the height is 224, preserving aspect ratio
+         image = Image.open(requests.get(img_url, stream=True).raw)
+         width, height = image.size
+         new_height = 224
+         new_width = new_height * width // height
+         image = image.resize((new_width, new_height), Image.LANCZOS)
+         image_processed = self.processor(images=image, return_tensors="pt")['pixel_values']
+         image_squeezed = image_processed.squeeze(0)
+         tokenized_caption = self.tokenizer(caption, return_tensors="pt", return_attention_mask=False)
+         tokenized_caption_ids = tokenized_caption['input_ids'].squeeze(0)
+         return (image_squeezed, tokenized_caption_ids)
+
+
+ def collate_fn_phase1(batch):
+     # Unzip the batch
+     images, captions = zip(*batch)
+     # Stack the processed images
+     images_stacked = torch.stack(images, dim=0)
+     # Pad the captions; the padding value is the <eos> token id
+     captions_padded = torch.nn.utils.rnn.pad_sequence(captions, batch_first=True, padding_value=50256)
+     # Return the stacked images and padded captions
+     return (images_stacked, captions_padded)
+
+
+ def get_data_loaders_phase1(data_dir, clip_model_name, tokenizer, train_batch_size, val_batch_size, num_workers):
+     # Load the data
+     with open(os.path.join(data_dir, 'coco_train.pkl'), 'rb') as fp:
+         train_pkl = pickle.load(fp)
+     with open(os.path.join(data_dir, 'coco_val.pkl'), 'rb') as fp:
+         val_pkl = pickle.load(fp)
+     # Train data loader
+     train_dataloader = DataLoader(ClipDataset(train_pkl, clip_model_name, tokenizer), collate_fn=collate_fn_phase1, batch_size=train_batch_size, num_workers=num_workers, shuffle=True, pin_memory=True)
+     # Val data loader
+     val_dataloader = DataLoader(ClipDataset(val_pkl, clip_model_name, tokenizer), collate_fn=collate_fn_phase1, batch_size=val_batch_size, num_workers=num_workers, shuffle=False, pin_memory=True)
+     return train_dataloader, val_dataloader
+
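Note: a minimal sketch (not part of the commit) of driving the phase-1 loaders, assuming get_coco.py has already written coco_train.pkl and coco_val.pkl under ./data.

    from transformers import AutoTokenizer
    from configs import get_config_phase1
    from dataset import get_data_loaders_phase1

    config = get_config_phase1()
    tokenizer = AutoTokenizer.from_pretrained(config["phi2_model_name"], trust_remote_code=True)
    train_dl, val_dl = get_data_loaders_phase1(config["data_dir"], config["clip_model_name"], tokenizer,
                                               config["train_batch_size"], config["val_batch_size"], config["num_workers"])
    images, captions = next(iter(train_dl))
    print(images.shape, captions.shape)  # (B, 3, 224, 224) pixel values, (B, T) padded caption ids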
+ ##################################### Phase 2 #########################################
+
+
+ class ClipDatasetPhase2(Dataset):
+     '''Dataset of (image, question, answer) triples for phase-2 instruction tuning'''
+     def __init__(self, data_frame, model_name, tokenizer):
+         self.tokenizer = tokenizer
+         self.processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
+         self.df = data_frame
+
+     def __len__(self):
+         # Return the length of the dataset
+         return len(self.df)
+
+     def __getitem__(self, idx):
+         # Get the image url and QA pair (the DataLoader passes a plain int index)
+         img_url = self.df.ImageUrl[idx]
+         que = self.df.Question[idx]
+         ans = self.df.Answer[idx]
+
+         # Text-only samples (e.g. OpenAssistant rows) carry no image
+         if img_url is None:
+             image_squeezed = None
+         else:
+             image = Image.open(requests.get(img_url, stream=True).raw)
+             width, height = image.size
+             new_height = 224
+             new_width = new_height * width // height
+             image = image.resize((new_width, new_height), Image.LANCZOS)
+             image_processed = self.processor(images=image, return_tensors="pt")['pixel_values']
+             image_squeezed = image_processed.squeeze(0)
+         que_ids = self.tokenizer(que, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
+         ans_ids = self.tokenizer(ans, return_tensors="pt", return_attention_mask=False)['input_ids'].squeeze(0)
+         return (image_squeezed, que_ids, ans_ids)
+
+
+ def collate_fn_phase2(batch):
+     # Unzip the batch
+     images, ques, ans = zip(*batch)
+     # Stack the processed images; the batch is text-only when the first entry is None
+     if images[0] is None:
+         images_stacked = None
+     else:
+         images_stacked = torch.stack(images, dim=0)
+     # Pad the QAs; the padding value is the <eos> token id
+     ques_padded = torch.nn.utils.rnn.pad_sequence(ques, batch_first=True, padding_value=50256)
+     ans_padded = torch.nn.utils.rnn.pad_sequence(ans, batch_first=True, padding_value=50256)
+     # Return the stacked images and padded QAs
+     return (images_stacked, ques_padded, ans_padded)
+
+
+ def prep_data(df):
+     df_assistant = df[(df.role == "assistant") & (df["rank"] == 0.0)].copy()
+     df_prompter = df[(df.role == "prompter")].copy()
+     df_prompter = df_prompter.set_index("message_id")
+     df_assistant["Answer"] = df_assistant["text"].values
+
+     inputs = []
+     for _, row in df_assistant.iterrows():
+         prompt = df_prompter.loc[row.parent_id]
+         inputs.append(prompt.text)
+
+     df_assistant["Question"] = inputs
+     df_assistant["ImageUrl"] = None
+
+     df_assistant = df_assistant[df_assistant.lang == "en"]
+
+     df_assistant = df_assistant[
+         ["ImageUrl", "Question", "Answer", "message_id"]
+     ].rename(columns={"message_id": "Ids"})
+
+     return df_assistant
+
+
+ def get_i150_df(config):
+     with open(config.get("i150k_json"), "r") as fp:
+         i150k_json_read = json.load(fp)
+     max_chars = 100  # filter on character length, not token count
+     image_urls = []
+     ques_list = []
+     ans_list = []
+     id_list = []
+     for idx, data in enumerate(i150k_json_read):
+         image = data['image']
+         image_url = 'http://images.cocodataset.org/train2017/' + image
+         id_ = data["id"]
+         # Conversations alternate human/gpt turns, so consume them in pairs
+         iterator = iter(data['conversations'])
+         for i in iterator:
+             ques = i
+             ans = next(iterator)
+             if (len(ques["value"]) > max_chars or len(ans["value"]) > max_chars):
+                 continue
+             if ques["from"] == "human" and ans["from"] == "gpt":
+                 image_urls.append(image_url)
+                 ques_list.append(ques["value"].replace("<image>\n", "").replace("<image>", ""))
+                 ans_list.append(ans["value"])
+                 id_list.append(id_)
+     df_i150k = pd.DataFrame(list(zip(image_urls, ques_list, ans_list, id_list)),
+                             columns=["ImageUrl", "Question", "Answer", "Ids"])
+     # Random ~96/4 train/test split
+     msk = np.random.rand(len(df_i150k)) < 0.96
+     train_df = df_i150k[msk]
+     test_df = df_i150k[~msk]
+     return train_df, test_df
+
+
+ def get_oas_df(config):
+     train_ds, val_ds = load_dataset(config.get("QA_datasetName"), split=["train", "validation"])
+     train_df = prep_data(train_ds.to_pandas())
+     test_df = prep_data(val_ds.to_pandas())
+     return train_df, test_df
+
+
+ def get_data_loaders_phase2(tokenizer, config):
+     train_i150k, test_i150k = get_i150_df(config)
+     train_oas, test_oas = get_oas_df(config)
+
+     train_df = pd.concat([train_i150k, train_oas]).reset_index(drop=True)
+     val_df = pd.concat([test_i150k, test_oas]).reset_index(drop=True)
+     # Train data loader
+     train_dataloader = DataLoader(ClipDatasetPhase2(train_df, config.get("clip_model_name"), tokenizer), collate_fn=collate_fn_phase2, batch_size=config.get("train_batch_size"), num_workers=config.get("num_workers"), shuffle=True, pin_memory=True)
+     # Val data loader
+     val_dataloader = DataLoader(ClipDatasetPhase2(val_df, config.get("clip_model_name"), tokenizer), collate_fn=collate_fn_phase2, batch_size=config.get("val_batch_size"), num_workers=config.get("num_workers"), shuffle=False, pin_memory=True)
+     return train_dataloader, val_dataloader
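Note: the phase-2 counterpart, a sketch assuming llava_instruct_150k.json has been downloaded to ./data; batches can carry images=None for the text-only OpenAssistant samples.

    from transformers import AutoTokenizer
    from configs import get_config_phase2
    from dataset import get_data_loaders_phase2

    config = get_config_phase2()
    tokenizer = AutoTokenizer.from_pretrained(config["phi2_model_name"], trust_remote_code=True)
    train_dl, val_dl = get_data_loaders_phase2(tokenizer, config)
    images, ques, ans = next(iter(train_dl))  # images is None for a text-only batch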
get_coco.py ADDED
@@ -0,0 +1,41 @@
+ """Unpack the COCO archive and save the caption annotations as pickle files."""
+ import os, shutil, json
+ import pickle, argparse
+
+
+ def make_pkl(data_dir, dataset_json, train_flag=False):
+     # Build an id -> image lookup once instead of rescanning the image list for every annotation
+     images_by_id = {img['id']: img for img in dataset_json['images']}
+     coco_data_list = []
+     for data in dataset_json['annotations']:
+         image_id = data['image_id']
+         caption = data['caption']
+         img = images_by_id[image_id]
+         coco_data_list.append({'image_id': image_id, 'image_url': img['coco_url'], 'file_name': img['file_name'], 'caption': caption})
+     out_name = 'coco_train.pkl' if train_flag else 'coco_val.pkl'
+     with open(os.path.join(data_dir, out_name), 'wb') as f:
+         pickle.dump(coco_data_list, f)
+
+
+ def main(coco_path, data_dir):
+     coco_dir = os.path.dirname(coco_path)
+     # shutil.unpack_archive(coco_path, coco_dir)
+     with open(os.path.join(coco_dir, 'annotations/captions_train2017.json')) as f:
+         coco_train_dataset = json.load(f)
+     with open(os.path.join(coco_dir, 'annotations/captions_val2017.json')) as f:
+         coco_val_dataset = json.load(f)
+     make_pkl(data_dir, coco_train_dataset, train_flag=True)
+     make_pkl(data_dir, coco_val_dataset)  # dataset.py expects coco_val.pkl as well
+
+
+ if __name__ == '__main__':
+     parser = argparse.ArgumentParser()
+     parser.add_argument('--coco_path', type=str, default='coco.zip')
+     parser.add_argument('--data_dir', type=str, default='data')
+     args = parser.parse_args()
+     main(args.coco_path, args.data_dir)
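Note: for reference, a sketch of the records each pickle holds, assuming the standard COCO captions annotation schema.

    import pickle

    with open('data/coco_train.pkl', 'rb') as f:
        records = pickle.load(f)
    # Each record pairs one caption with its image metadata, e.g.
    # {'image_id': ..., 'image_url': 'http://images.cocodataset.org/...',
    #  'file_name': '....jpg', 'caption': '...'}
    print(len(records), records[0]['caption'])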
main.py ADDED
@@ -0,0 +1,41 @@
+ import torch
+ from dataset import get_data_loaders_phase1, get_data_loaders_phase2
+ from transformers import AutoTokenizer
+ from model import CustomClipPhi2, MainQLoraModel, train_model_phase1, train_model_phase2
+ from configs import get_config_phase1, get_config_phase2
+
+
+ def phase_1():
+     # Get config
+     config = get_config_phase1()
+     # Tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
+
+     # Data loaders
+     train_dataloader, val_dataloader = get_data_loaders_phase1(config.get("data_dir"), config.get("clip_model_name"), tokenizer, config.get("train_batch_size"), config.get("val_batch_size"), config.get("num_workers"))
+
+     llmModel = CustomClipPhi2(tokenizer, config.get("phi2_model_name"), config.get("clip_model_name"), clip_embed=config.get("clip_embed"), phi_embed=config.get("phi_embed")).to(config.get("device"))
+     print(llmModel)
+     # Optimizer (only the projection layer has trainable parameters)
+     optimizer = torch.optim.Adam(filter(lambda p: p.requires_grad, llmModel.parameters()), lr=1e-3)
+     # Train model
+     train_model_phase1(llmModel, train_dataloader, val_dataloader, optimizer, tokenizer, config)
+
+
+ def phase_2():
+     # Get config
+     config = get_config_phase2()
+     # Tokenizer
+     tokenizer = AutoTokenizer.from_pretrained(config.get("phi2_model_name"), trust_remote_code=True)
+
+     # Data loaders
+     train_dataloader, val_dataloader = get_data_loaders_phase2(tokenizer, config)
+
+     llmModel = MainQLoraModel(tokenizer, config).to(config.get("device"))
+     print(llmModel)
+     # Train model
+     train_model_phase2(llmModel, train_dataloader, val_dataloader, tokenizer, config)
+
+
+ if __name__ == "__main__":
+     torch.set_float32_matmul_precision('medium')
+     phase_1()
+     # phase_2()
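Note: a hypothetical convenience wrapper (not in the commit) for picking the phase from the command line instead of editing the calls above.

    import argparse
    from main import phase_1, phase_2

    parser = argparse.ArgumentParser()
    parser.add_argument('--phase', type=int, choices=[1, 2], default=1)
    args = parser.parse_args()
    (phase_1 if args.phase == 1 else phase_2)()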
model.py ADDED
@@ -0,0 +1,336 @@
+ import torch
+ import torch.nn as nn
+ from torch.nn.functional import cross_entropy
+ from transformers import CLIPVisionModel, AutoModelForCausalLM, BitsAndBytesConfig
+ from peft import LoraConfig
+ from tqdm import tqdm
+ import os, peft
+
+
+ class CustomClipPhi2(nn.Module):
+     def __init__(self, tokenizer, phi2_model_name, clip_model_name, clip_embed=768, phi_embed=2560):
+         super().__init__()
+
+         self.tokenizer = tokenizer
+         # These two models are not finetuned:
+         # pretrained Microsoft phi-2 model
+         self.phi2_model = AutoModelForCausalLM.from_pretrained(phi2_model_name, torch_dtype=torch.float32, trust_remote_code=True)
+         # pretrained OpenAI CLIP vision model
+         self.clip_model = CLIPVisionModel.from_pretrained(clip_model_name)
+
+         self.EOS_TOKEN_ID = self.tokenizer.eos_token_id  # 50256
+         self.IMAGE_TOKEN_ID = 23903  # token for "Comments"
+         self.clip_embed = clip_embed
+         self.phi_embed = phi_embed
+
+         # Trainable projection layer from CLIP space into phi-2 embedding space
+         self.projection_layer = torch.nn.Linear(clip_embed, phi_embed)
+
+         # Freeze the weights of both backbones
+         for model in [self.phi2_model, self.clip_model]:
+             for param in model.parameters():
+                 param.requires_grad_(False)
+
+         # Load checkpoint weights if available
+         if os.path.exists('./ckpts/model_phase1.pth'):
+             self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location='cpu'))
+             print("Loaded checkpoint weights for projection layer")
+         else:
+             print("No checkpoint weights for projection layer")
+             print("Initializing projection layer with random weights")
+             self.projection_layer.weight.data.normal_(mean=0.0, std=0.02)
+             self.projection_layer.bias.data.zero_()
+
+     def generate(self, images, tokenizer, config):
+         clip_outputs = self.clip_model(**images)
+         # Remove the CLS token
+         images = clip_outputs.last_hidden_state[:, 1:, :]
+         image_embeddings = self.projection_layer(images).to(torch.float16)
+
+         batch_size = images.size()[0]
+         predicted_caption = torch.full((batch_size, config.get("max_tokens")), self.EOS_TOKEN_ID, dtype=torch.long, device=config.get('device'))
+         img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
+         img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
+         combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)
+
+         # Greedy decoding: take the argmax token and feed its embedding back in
+         for pos in range(config.get("max_tokens") - 1):
+             model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
+             predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+             predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
+             predicted_caption[:, pos] = predicted_word_token.squeeze(1)
+             next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
+             combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+         return predicted_caption
+
+     def forward(self, images, target_captions):
+         batch_size = target_captions.size()[0]
+         target_length = target_captions.size()[1]
+
+         # CLIP output for the image; see https://huggingface.co/openai/clip-vit-base-patch16 for loading
+         clip_outputs = self.clip_model(**images)
+         images = clip_outputs.last_hidden_state[:, 1:, :]  # remove CLS token
+
+         # Projection layer
+         image_embeddings = self.projection_layer(images).to(torch.float16)
+
+         # Append the image token from phi-2
+         img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1)
+         img_token_embeds = self.phi2_model.model.embed_tokens(img_token_tensor.to(image_embeddings.device))
+         combined_embeds = torch.cat([image_embeddings, img_token_embeds], dim=1)  # (batch, patches + 1, 2560)
+         del clip_outputs
+         del image_embeddings
+
+         # Autoregressive loss; the model's own predictions are fed back in (no teacher forcing)
+         loss = 0
+         for pos in range(target_length - 1):
+             model_output_logits = self.phi2_model.forward(inputs_embeds=combined_embeds)['logits']
+             predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+             pos_loss = cross_entropy(predicted_word_token_logits.view(-1, predicted_word_token_logits.size(-1)), target_captions[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
+             loss += pos_loss
+
+             predicted_word_token = torch.argmax(predicted_word_token_logits, dim=-1)
+             next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
+             combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+         loss = loss / (target_length - 1)  # average over the decoded positions
+
+         # Delete variables to free up memory
+         del combined_embeds
+         del model_output_logits
+         torch.cuda.empty_cache()
+
+         return loss
+
+
+ def show_results_for_samples_phase1(model, val_dataloader, tokenizer, config, num_samples=2):
+     model.eval()
+     with torch.no_grad():
+         for i in range(num_samples):
+             for images, target_captions in val_dataloader:
+                 images = {'pixel_values': images.to(config.get('device'))}
+                 target_captions = target_captions.to(config.get('device'))
+                 target_captions_decoded = tokenizer.batch_decode(target_captions, skip_special_tokens=True)
+                 predicted_captions = model.generate(images, tokenizer, config)
+                 predicted_captions_decoded = tokenizer.batch_decode(predicted_captions, skip_special_tokens=True)
+
+                 for idx, pc in enumerate(predicted_captions_decoded):
+                     print(f"{idx} - Target caption: {target_captions_decoded[idx]} \n {'-' * 100} \n Predicted caption: {pc}")
+                 break
+
+
+ def validate_model_phase1(model, val_dataloader, tokenizer, config):
+     model.eval()
+     total_loss = 0
+     with torch.no_grad():
+         try:
+             for images, target_captions in tqdm(val_dataloader):
+                 images = {'pixel_values': images.to(config.get('device'))}
+                 target_captions = target_captions.to(config.get('device'))
+                 loss = model(images, target_captions)
+                 total_loss += loss.item()
+             print(f"Validation Loss: {total_loss / len(val_dataloader)}")
+         except Exception as e:
+             print(e)
+     model.train()
+
+
+ def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer, config):
+     model.train()
+
+     for epoch in range(1, config.get("epochs") + 1):
+         print(f"Epoch: {epoch}")
+         torch.cuda.empty_cache()
+         step = 1
+         # Recreate the progress bar each epoch; a tqdm iterator is exhausted after one pass
+         pbar = tqdm(train_loader)
+         try:
+             for idx, (images, target_captions) in enumerate(pbar):
+                 try:
+                     if target_captions.shape[1] >= config.get("max_tokens"):
+                         # print(f"Skipping batch {idx} due to long caption")
+                         continue
+
+                     images = {'pixel_values': images.to(config.get('device'))}
+                     target_captions = target_captions.to(config.get('device'))
+
+                     optimizer.zero_grad()
+                     loss = model(images, target_captions)
+                     loss.backward()
+                     optimizer.step()
+                     pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
+                     torch.cuda.empty_cache()
+                     step += 1
+                     if (step % 1000 == 0):
+                         torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
+                 except Exception as e:
+                     continue
+
+             # Only save the last checkpoint of each epoch
+             validate_model_phase1(model, val_dataloader, tokenizer, config)
+             show_results_for_samples_phase1(model, val_dataloader, tokenizer, config)
+             torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
+
+         except Exception as e:
+             continue
+
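Note: a minimal inference sketch (not part of the commit) for the phase-1 model; the COCO image URL is only an example.

    import requests
    from PIL import Image
    from transformers import AutoProcessor, AutoTokenizer
    from configs import get_config_phase1
    from model import CustomClipPhi2

    config = get_config_phase1()
    tokenizer = AutoTokenizer.from_pretrained(config["phi2_model_name"], trust_remote_code=True)
    processor = AutoProcessor.from_pretrained(config["clip_model_name"])
    model = CustomClipPhi2(tokenizer, config["phi2_model_name"], config["clip_model_name"]).to(config["device"]).eval()

    url = "http://images.cocodataset.org/val2017/000000039769.jpg"
    image = Image.open(requests.get(url, stream=True).raw)
    pixels = processor(images=image, return_tensors="pt")["pixel_values"].to(config["device"])
    caption_ids = model.generate({"pixel_values": pixels}, tokenizer, config)
    print(tokenizer.batch_decode(caption_ids, skip_special_tokens=True)[0])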
+ ######################################## Phase 2 #########################################
+
+ class MainQLoraModel(nn.Module):
+     def __init__(self, tokenizer, config):
+         super().__init__()
+         self.tokenizer = tokenizer
+         self.config = config
+         self.clip_model = CLIPVisionModel.from_pretrained(config.get("clip_model_name"))
+
+         # 4-bit quantization config for QLoRA
+         bnb_config = BitsAndBytesConfig(
+             load_in_4bit=True,
+             bnb_4bit_quant_type="nf4",
+             bnb_4bit_compute_dtype=torch.float16,
+         )
+
+         phi2_model = AutoModelForCausalLM.from_pretrained(
+             config.get("phi2_model_name"),
+             quantization_config=bnb_config,
+             trust_remote_code=True
+         )
+         phi2_model.config.use_cache = False
+
+         # LoRA config
+         lora_alpha = 16
+         lora_dropout = 0.1
+         lora_r = 64
+
+         peft_config = LoraConfig(
+             lora_alpha=lora_alpha,
+             lora_dropout=lora_dropout,
+             r=lora_r,
+             bias="none",
+             task_type="CAUSAL_LM",
+             target_modules=[
+                 "q_proj",
+                 "k_proj",
+                 "v_proj",
+                 "dense",
+                 "fc1",
+                 "fc2"
+             ]
+         )
+         self.phi2_model = peft.get_peft_model(phi2_model, peft_config).to(config.get("device"))
+
+         self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
+         self.IMAGE_TOKEN_ID = 23903  # token for "Comments"
+         self.clip_embed = config.get("clip_embed")
+         self.phi_embed = config.get("phi_embed")
+
+         # Trainable projection layer
+         self.projection_layer = torch.nn.Linear(self.clip_embed, self.phi_embed)
+
+         # Freeze CLIP weights
+         for param in self.clip_model.parameters():
+             param.requires_grad_(False)
+
+         # Load checkpoint weights
+         if os.path.exists('./ckpts/model_phase2.pth'):
+             self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location=config.get("device")))
+             # Rebuild the PEFT model from the saved QLoRA adapter
+             self.phi2_model = peft.PeftModel.from_pretrained(phi2_model, './ckpts/Qlora_adaptor', is_trainable=True).to(config.get("device"))
+             print("Loaded checkpoint weights for projection layer and QLoRA adapter")
+         else:
+             # Load the projection weights from phase 1
+             self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location=config.get("device")))
+
+     def forward(self, images, ques, ans):
+         batch_size = ques.size()[0]
+         questions = ques.to(self.config.get("device"))
+         answers = ans.to(self.config.get("device"))
+
+         questions_embed = self.phi2_model.model.model.embed_tokens(questions)
+         if images is None:
+             combined_embeds = questions_embed
+         else:
+             images = {'pixel_values': images.to(self.config.get("device"))}
+             clip_outputs = self.clip_model(**images)
+             images_embeds = clip_outputs.last_hidden_state[:, 1:, :]  # remove CLS token
+
+             # Projection
+             image_embeds = self.projection_layer(images_embeds).to(torch.float16)
+             img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1).to(self.config.get("device"))
+             img_token_embeds = self.phi2_model.model.model.embed_tokens(img_token_tensor)
+             combined_embeds = torch.cat([image_embeds, img_token_embeds, questions_embed], dim=1)
+
+         phi_output_logits = self.phi2_model(inputs_embeds=combined_embeds)['logits']
+
+         if images is not None:
+             # Remove the logits over the image patches and the image token
+             phi_output_logits = phi_output_logits[:, images_embeds.shape[1] + 2:, :]
+
+         phi_output_logits = phi_output_logits.reshape(-1, self.config.get("vocab_size"))
+
+         loss = cross_entropy(phi_output_logits, answers.contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
+
+         return loss
+
+
+ def validate_model_phase2(model, val_dataloader, tokenizer, config):
+     model.eval()
+     total_loss = 0
+     with torch.no_grad():
+         try:
+             for images, ques, ans in tqdm(val_dataloader):
+                 loss = model(images, ques, ans)
+                 total_loss += loss.item()
+             print(f"Validation Loss: {total_loss / len(val_dataloader)}")
+         except Exception as e:
+             print(e)
+     model.train()
+
+
+ def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
+     phi2_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.phi2_model.parameters()), lr=1e-5)
+     proj_optim = torch.optim.Adam(filter(lambda p: p.requires_grad, model.projection_layer.parameters()), lr=1e-5)
+     model.phi2_model.train()
+     model.projection_layer.train()
+
+     for epoch in range(1, config.get("epochs") + 1):
+         print(f"Epoch: {epoch}")
+         torch.cuda.empty_cache()
+         step = 1
+         # Recreate the progress bar each epoch; a tqdm iterator is exhausted after one pass
+         pbar = tqdm(train_loader)
+         try:
+             for idx, (images, ques, ans) in enumerate(pbar):
+                 try:
+                     phi2_optim.zero_grad()
+                     proj_optim.zero_grad()
+                     loss = model(images, ques, ans)
+                     loss.backward()
+                     phi2_optim.step()
+                     proj_optim.step()
+                     pbar.set_description(f"Epoch: {epoch}: Training Loss = {loss.item()}")
+                     torch.cuda.empty_cache()
+                     step += 1
+                     if (step % 1000 == 0):
+                         torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
+                         model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
+                 except Exception as e:
+                     print(e)
+                     continue
+
+             validate_model_phase2(model, val_dataloader, tokenizer, config)
+             torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
+             model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
+
+         except Exception as e:
+             print(e)
+             continue
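Note: a sketch (not part of the commit) of reloading the phase-2 artifacts saved above for inference; it assumes the checkpoints exist under ./ckpts.

    import torch
    import peft
    from transformers import AutoModelForCausalLM
    from configs import get_config_phase2

    config = get_config_phase2()
    base = AutoModelForCausalLM.from_pretrained(config["phi2_model_name"], trust_remote_code=True)
    phi2 = peft.PeftModel.from_pretrained(base, './ckpts/Qlora_adaptor')  # adapter saved by train_model_phase2
    projection = torch.nn.Linear(config["clip_embed"], config["phi_embed"])
    projection.load_state_dict(torch.load('./ckpts/model_phase2.pth', map_location='cpu'))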