Vasudevakrishna committed
Commit 7417ef5 · verified · 1 Parent(s): ebc4357

Update model.py

Files changed (1)
  1. model.py +70 -28
model.py CHANGED
@@ -69,6 +69,7 @@ class CustomClipPhi2(nn.Module):

         batch_size = target_captions.size()[0]
         target_length = target_captions.size()[1]
+        print("---", target_length)

         # clip model output for image
         clip_outputs = self.clip_model(**images) # See this for loading https://huggingface.co/openai/clip-vit-base-patch36
@@ -166,6 +167,7 @@ def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer
                 if (step%1000==0):
                     torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
             except Exception as e:
+                print(e)
                 continue

     # # save model
@@ -176,6 +178,7 @@ def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer
         torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')

     except Exception as e:
+        print(e)
         continue


@@ -227,7 +230,6 @@ class MainQLoraModel(nn.Module):
         self.phi2_model = peft.get_peft_model(phi2_model, peft_config).to(config.get("device"))

         self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
-        self.IMAGE_TOKEN_ID = 23903 # token for Comments
         self.clip_embed = config.get("clip_embed")
         self.phi_embed = config.get("phi_embed")

@@ -250,49 +252,90 @@ class MainQLoraModel(nn.Module):
         self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location=config.get("device")))


+    def generate(self, tokenizer, config, images = None, ques = None, max_tokens = 100):
+        batch_size = 1
+
+        predicted_caption = torch.full((batch_size, max_tokens), self.EOS_TOKEN_ID, dtype=torch.long, device=self.config.get('device'))
+        start_iq = self.tokenizer.encode("<iQ>")
+        end_iq = self.tokenizer.encode("</iQ>")
+        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
+        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
+        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
+        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
+        questions_embed = self.phi2_model.model.model.embed_tokens(ques)
+        if images is not None:
+            clip_outputs = self.clip_model(**images)
+            # remove cls token
+            images = clip_outputs.last_hidden_state[:, 1:, :]
+            image_embeddings = self.projection_layer(images).to(torch.float16)
+            combined_embeds = torch.cat([start_iq_embeds, image_embeddings, questions_embed, end_iq_embeds], dim=1)
+        else:
+            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds], dim=1)
+
+        for pos in range(max_tokens - 1):
+            model_output_logits = self.phi2_model.forward(inputs_embeds = combined_embeds)['logits']
+            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+            predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
+            predicted_caption[:, pos] = predicted_word_token.view(1,-1).to('cpu')
+            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
+            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+        return predicted_caption
+
+
     def forward(self, images, ques, ans):

         batch_size = ques.size()[0]
         questions = ques.to(self.config.get("device"))
         answers = ans.to(self.config.get("device"))
-
-        questions_embed = peft_model.model.model.embed_tokens(questions)
-        if images is None:
-            combined_embeds = questions_embed
+        target_length = ans.size()[1]
+        start_iq = self.tokenizer.encode("<iQ>")
+        end_iq = self.tokenizer.encode("</iQ>")
+        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
+        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
+        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
+        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
+
+        questions_embed = self.phi2_model.model.model.embed_tokens(questions)
+        answers_embed = self.phi2_model.model.model.embed_tokens(answers)
+
+        are_all_zeros = torch.all(images == 0).item()
+        if are_all_zeros:
+            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
         else:
             images = {'pixel_values': images.to(self.config.get("device"))}
-            clip_outputs = clip_model(**images)
+            clip_outputs = self.clip_model(**images)
             images_embeds = clip_outputs.last_hidden_state[:,1:,:] # remove cls token

             # projection
-            image_embeds = projection(images_embeds).to(torch.float16)
-            img_token_tensor = torch.tensor(self.IMAGE_TOKEN_ID).repeat(batch_size, 1).to(self.config.get("device"))
-            img_token_embeds = peft_model.model.model.embed_tokens(img_token_tensor)
-            combined_embeds = torch.cat([image_embeds, img_token_embeds, questions_embed], dim=1)
-
-        phi_output_logits = peft_model(inputs_embeds=combined_embeds)['logits']
-
-        if images is not None:
-            # remove image and image token embeddings
-            phi_output_logits = phi_output_logits[:,images_embeds.shape[1] + 2 : ,:]
-
-        phi_output_logits = phi_output_logits.reshape(-1, self.config.get("vocab_size"))
+            image_embeds = self.projection_layer(images_embeds).to(torch.float16)
+            combined_embeds = torch.cat([start_iq_embeds, image_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)

-        loss = cross_entropy(phi_output_logits, answers.contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID, label_smoothing=0.1)
+        model_output_logits = self.phi2_model.forward(inputs_embeds = combined_embeds)['logits']
+        # # for loss
+        loss = 0
+        for pos in range(target_length - 1):
+            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+            pos_loss = cross_entropy(predicted_word_token_logits.view(-1,predicted_word_token_logits.size(-1)), answers[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID,label_smoothing=0.1)
+            loss += pos_loss
+        loss = loss / target_length

+        # Delete variables to free up memory
+        del combined_embeds
+        del model_output_logits
+        torch.cuda.empty_cache()
         return loss

 def validate_model_phase2(model, val_dataloader, tokenizer, config):
     model.eval()
     total_loss = 0
     with torch.no_grad():
-        try:
-            for images, ques, ans in tqdm(val_dataloader):
-                loss = model(images, ques, ans)
-                total_loss+=loss.item()
-            print(f"Validation Loss: {total_loss/len(val_dataloader)}")
-        except Exception as e:
-            pass
+        # try:
+        for images, ques, ans in tqdm(val_dataloader):
+            loss = model(images, ques, ans)
+            total_loss+=loss.item()
+        print(f"Validation Loss: {total_loss/len(val_dataloader)}")
+        # except Exception as e:
+        # pass
     model.train()


@@ -310,7 +353,6 @@ def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
     try:
         for idx, (images, ques, ans) in enumerate(pbar):
             try:
-                print("hi")
                 phi2_optim.zero_grad()
                 proj_optim.zero_grad()
                 loss = model(images, ques, ans)
@@ -324,7 +366,7 @@ def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
                     torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
                     model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
             except Exception as e:
-                print(e)
+                print("in frp",e)
                 continue

     validate_model_phase2(model, val_dataloader, tokenizer, config)
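
For reference, a minimal sketch of how the newly added MainQLoraModel.generate method might be called after phase-2 training. The constructor signature, the config values, the CLIP checkpoint, the processor, and the file paths below are illustrative assumptions and are not part of this commit.

# Hypothetical usage sketch for MainQLoraModel.generate (assumptions noted above).
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPImageProcessor

from model import MainQLoraModel  # assumed import path

# Assumed config; in practice reuse the same dict used during training.
config = {"device": "cuda" if torch.cuda.is_available() else "cpu",
          "clip_embed": 768, "phi_embed": 2560, "vocab_size": 51200}
device = config.get("device")

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2", trust_remote_code=True)
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")  # assumed CLIP checkpoint

model = MainQLoraModel(config).to(device)  # constructor signature assumed; see model.py
model.eval()

# generate() expects a {'pixel_values': ...} dict for CLIP and a LongTensor of phi-2 token ids on the device.
image = Image.open("example.jpg").convert("RGB")
pixel_values = processor(images=image, return_tensors="pt")["pixel_values"].to(device)
question_ids = torch.tensor([tokenizer.encode("What is shown in the image?")], device=device)

with torch.no_grad():
    caption_ids = model.generate(tokenizer, config,
                                 images={"pixel_values": pixel_values},
                                 ques=question_ids,
                                 max_tokens=100)
print(tokenizer.decode(caption_ids[0], skip_special_tokens=True))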