asdjghh committed on
Commit
dd6daa0
·
verified ·
1 Parent(s): 84a4854

Upload modeling_vlm.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_vlm.py +5 -1
modeling_vlm.py CHANGED
@@ -330,6 +330,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
330
  hidden_states = outputs.last_hidden_state
331
  logits = self.gen_head(hidden_states)
332
 
 
333
  logits_cond = logits[0::2, :]
334
  logits_uncond = logits[1::2, :]
335
 
@@ -645,7 +646,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
645
  # print('nmsl: ', ii, ind)
646
  if ii % 4 == 0:
647
  offset = ind[1] + 2
648
- inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[(ii // 2) % img_len]
649
 
650
  generated_tokens = torch.zeros((3 * input_ids.size(0), image_token_num_per_image), dtype=torch.int).cuda()
651
 
@@ -657,6 +658,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
657
  # torch.save(inputs_embeds, '/data/bxh_data/unify_model/hidden_states.pt')
658
 
659
  logits = self.gen_head(hidden_states)
 
 
660
 
661
  # logits_cond = logits[0::2, :]
662
  # logits_uncond = logits[1::2, :]
@@ -679,6 +682,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
679
  shift_labels = labels[..., 1:].contiguous()
680
  shift_labels = shift_labels.view(-1)
681
  shift_labels = shift_labels.to(shift_logits.device)
 
682
  loss = loss_fct(shift_logits, shift_labels)
683
  else:
684
  loss = None
 
330
  hidden_states = outputs.last_hidden_state
331
  logits = self.gen_head(hidden_states)
332
 
333
+
334
  logits_cond = logits[0::2, :]
335
  logits_uncond = logits[1::2, :]
336
 
 
646
  # print('nmsl: ', ii, ind)
647
  if ii % 4 == 0:
648
  offset = ind[1] + 2
649
+ inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[ii // 4]
650
 
651
  generated_tokens = torch.zeros((3 * input_ids.size(0), image_token_num_per_image), dtype=torch.int).cuda()
652
 
 
658
  # torch.save(inputs_embeds, '/data/bxh_data/unify_model/hidden_states.pt')
659
 
660
  logits = self.gen_head(hidden_states)
661
+ print('logits.shape', logits.shape) # [3, 1760, 16384])
662
+ print(labels.shape) # [3, 1760]
663
 
664
  # logits_cond = logits[0::2, :]
665
  # logits_uncond = logits[1::2, :]
 
682
  shift_labels = labels[..., 1:].contiguous()
683
  shift_labels = shift_labels.view(-1)
684
  shift_labels = shift_labels.to(shift_logits.device)
685
+ print(shift_logits.shape, shift_labels.shape)
686
  loss = loss_fct(shift_logits, shift_labels)
687
  else:
688
  loss = None