Update model.py
model.py
CHANGED

@@ -69,6 +69,7 @@ class CustomClipPhi2(nn.Module):
 
         batch_size = target_captions.size()[0]
         target_length = target_captions.size()[1]
+        print("---", target_length)
 
         # clip model output for image
         clip_outputs = self.clip_model(**images) # See this for loading https://huggingface.co/openai/clip-vit-base-patch36
@@ -166,6 +167,7 @@ def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer
                 if (step%1000==0):
                     torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
             except Exception as e:
+                print(e)
                 continue
 
     # # save model
@@ -176,6 +178,7 @@ def train_model_phase1(model, train_loader, val_dataloader, optimizer, tokenizer
             torch.save(model.projection_layer.state_dict(), './ckpts/model_phase1.pth')
 
         except Exception as e:
+            print(e)
            continue
 
 
@@ -227,7 +230,6 @@ class MainQLoraModel(nn.Module):
         self.phi2_model = peft.get_peft_model(phi2_model, peft_config).to(config.get("device"))
 
         self.EOS_TOKEN_ID = self.tokenizer.eos_token_id
-        self.IMAGE_TOKEN_ID = 23903 # token for Comments
         self.clip_embed = config.get("clip_embed")
         self.phi_embed = config.get("phi_embed")
 
@@ -250,49 +252,90 @@ class MainQLoraModel(nn.Module):
         self.projection_layer.load_state_dict(torch.load('./ckpts/model_phase1.pth', map_location=config.get("device")))
 
 
+    def generate(self, tokenizer, config, images = None, ques = None, max_tokens = 100):
+        batch_size = 1
+
+        predicted_caption = torch.full((batch_size, max_tokens), self.EOS_TOKEN_ID, dtype=torch.long, device=self.config.get('device'))
+        start_iq = self.tokenizer.encode("<iQ>")
+        end_iq = self.tokenizer.encode("</iQ>")
+        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
+        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
+        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
+        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
+        questions_embed = self.phi2_model.model.model.embed_tokens(ques)
+        if images is not None:
+            clip_outputs = self.clip_model(**images)
+            # remove cls token
+            images = clip_outputs.last_hidden_state[:, 1:, :]
+            image_embeddings = self.projection_layer(images).to(torch.float16)
+            combined_embeds = torch.cat([start_iq_embeds, image_embeddings, questions_embed, end_iq_embeds], dim=1)
+        else:
+            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds], dim=1)
+
+        for pos in range(max_tokens - 1):
+            model_output_logits = self.phi2_model.forward(inputs_embeds = combined_embeds)['logits']
+            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+            predicted_word_token = torch.argmax(predicted_word_token_logits, dim = -1)
+            predicted_caption[:, pos] = predicted_word_token.view(1,-1).to('cpu')
+            next_token_embeds = self.phi2_model.model.embed_tokens(predicted_word_token)
+            combined_embeds = torch.cat([combined_embeds, next_token_embeds], dim=1)
+        return predicted_caption
+
+
     def forward(self, images, ques, ans):
 
         batch_size = ques.size()[0]
         questions = ques.to(self.config.get("device"))
         answers = ans.to(self.config.get("device"))
-
-
-
-
+        target_length = ans.size()[1]
+        start_iq = self.tokenizer.encode("<iQ>")
+        end_iq = self.tokenizer.encode("</iQ>")
+        start_iq_embeds = torch.tensor(start_iq).repeat(batch_size, 1)
+        end_iq_embeds = torch.tensor(end_iq).repeat(batch_size, 1)
+        start_iq_embeds = self.phi2_model.model.model.embed_tokens(start_iq_embeds.to(self.config.get("device")))
+        end_iq_embeds = self.phi2_model.model.model.embed_tokens(end_iq_embeds.to(self.config.get("device")))
+
+        questions_embed = self.phi2_model.model.model.embed_tokens(questions)
+        answers_embed = self.phi2_model.model.model.embed_tokens(answers)
+
+        are_all_zeros = torch.all(images == 0).item()
+        if are_all_zeros:
+            combined_embeds = torch.cat([start_iq_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
         else:
             images = {'pixel_values': images.to(self.config.get("device"))}
-            clip_outputs = clip_model(**images)
+            clip_outputs = self.clip_model(**images)
             images_embeds = clip_outputs.last_hidden_state[:,1:,:] # remove cls token
 
             # projection
-            image_embeds =
-
-            img_token_embeds = peft_model.model.model.embed_tokens(img_token_tensor)
-            combined_embeds = torch.cat([image_embeds, img_token_embeds, questions_embed], dim=1)
-
-        phi_output_logits = peft_model(inputs_embeds=combined_embeds)['logits']
-
-        if images is not None:
-            # remove image and image token embeddings
-            phi_output_logits = phi_output_logits[:,images_embeds.shape[1] + 2 : ,:]
-
-        phi_output_logits = phi_output_logits.reshape(-1, self.config.get("vocab_size"))
+            image_embeds = self.projection_layer(images_embeds).to(torch.float16)
+            combined_embeds = torch.cat([start_iq_embeds, image_embeds, questions_embed, end_iq_embeds, answers_embed], dim=1)
 
-
+        model_output_logits = self.phi2_model.forward(inputs_embeds = combined_embeds)['logits']
+        # # for loss
+        loss = 0
+        for pos in range(target_length - 1):
+            predicted_word_token_logits = model_output_logits[:, -1, :].unsqueeze(1)
+            pos_loss = cross_entropy(predicted_word_token_logits.view(-1,predicted_word_token_logits.size(-1)), answers[:, pos].contiguous().view(-1), ignore_index=self.EOS_TOKEN_ID,label_smoothing=0.1)
+            loss += pos_loss
+        loss = loss / target_length
 
+        # Delete variables to free up memory
+        del combined_embeds
+        del model_output_logits
+        torch.cuda.empty_cache()
         return loss
 
 def validate_model_phase2(model, val_dataloader, tokenizer, config):
     model.eval()
     total_loss = 0
     with torch.no_grad():
-        try:
-
-
-
-
-        except Exception as e:
-
+        # try:
+        for images, ques, ans in tqdm(val_dataloader):
+            loss = model(images, ques, ans)
+            total_loss+=loss.item()
+        print(f"Validation Loss: {total_loss/len(val_dataloader)}")
+        # except Exception as e:
+        #     pass
     model.train()
 
 
@@ -310,7 +353,6 @@ def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
     try:
         for idx, (images, ques, ans) in enumerate(pbar):
             try:
-                print("hi")
                 phi2_optim.zero_grad()
                 proj_optim.zero_grad()
                 loss = model(images, ques, ans)
@@ -324,7 +366,7 @@ def train_model_phase2(model, train_loader, val_dataloader, tokenizer, config):
                     torch.save(model.projection_layer.state_dict(), './ckpts/model_phase2.pth')
                     model.phi2_model.save_pretrained('./ckpts/Qlora_adaptor/', save_adapter=True, save_config=True)
             except Exception as e:
-                print(e)
+                print("in frp",e)
                 continue
 
     validate_model_phase2(model, val_dataloader, tokenizer, config)
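
For context, the main addition in this commit is the greedy-decoding generate method on MainQLoraModel, which wraps the projected CLIP image embeddings and the question tokens between <iQ> and </iQ> and then decodes up to max_tokens tokens one position at a time. The sketch below shows one plausible way to call it after training; it is illustrative only. The tokenizer and CLIP processor checkpoints, the config dict, and the preprocessing steps are assumptions and are not part of this diff.

# Illustrative sketch: assumes `model` is a MainQLoraModel with the phase-1/phase-2
# checkpoints already loaded, and `config` is the same dict used during training.
import torch
from PIL import Image
from transformers import AutoTokenizer, CLIPImageProcessor

tokenizer = AutoTokenizer.from_pretrained("microsoft/phi-2")                      # assumed tokenizer
processor = CLIPImageProcessor.from_pretrained("openai/clip-vit-base-patch32")    # assumed CLIP checkpoint

# generate() calls self.clip_model(**images), so images is passed as a dict with 'pixel_values'
pixel_values = processor(images=Image.open("example.jpg"), return_tensors="pt")["pixel_values"]
images = {"pixel_values": pixel_values.to(config.get("device"))}

# the question is passed as token ids already on the model's device
ques = tokenizer.encode("What is shown in the image?", return_tensors="pt").to(config.get("device"))

model.eval()
with torch.no_grad():
    token_ids = model.generate(tokenizer, config, images=images, ques=ques, max_tokens=100)

print(tokenizer.decode(token_ids[0], skip_special_tokens=True))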