Upload modeling_vlm.py with huggingface_hub
Browse files — modeling_vlm.py (+5 −1)
modeling_vlm.py
CHANGED
|
@@ -330,6 +330,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|
| 330 |
hidden_states = outputs.last_hidden_state
|
| 331 |
logits = self.gen_head(hidden_states)
|
| 332 |
|
|
|
|
| 333 |
logits_cond = logits[0::2, :]
|
| 334 |
logits_uncond = logits[1::2, :]
|
| 335 |
|
|
@@ -645,7 +646,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|
| 645 |
# print('nmsl: ', ii, ind)
|
| 646 |
if ii % 4 == 0:
|
| 647 |
offset = ind[1] + 2
|
| 648 |
-
inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[
|
| 649 |
|
| 650 |
generated_tokens = torch.zeros((3 * input_ids.size(0), image_token_num_per_image), dtype=torch.int).cuda()
|
| 651 |
|
|
@@ -657,6 +658,8 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|
| 657 |
# torch.save(inputs_embeds, '/data/bxh_data/unify_model/hidden_states.pt')
|
| 658 |
|
| 659 |
logits = self.gen_head(hidden_states)
|
|
|
|
|
|
|
| 660 |
|
| 661 |
# logits_cond = logits[0::2, :]
|
| 662 |
# logits_uncond = logits[1::2, :]
|
|
@@ -679,6 +682,7 @@ class MultiModalityCausalLM(MultiModalityPreTrainedModel):
|
|
| 679 |
shift_labels = labels[..., 1:].contiguous()
|
| 680 |
shift_labels = shift_labels.view(-1)
|
| 681 |
shift_labels = shift_labels.to(shift_logits.device)
|
|
|
|
| 682 |
loss = loss_fct(shift_logits, shift_labels)
|
| 683 |
else:
|
| 684 |
loss = None
|
|
|
|
| 330 |
hidden_states = outputs.last_hidden_state
|
| 331 |
logits = self.gen_head(hidden_states)
|
| 332 |
|
| 333 |
+
|
| 334 |
logits_cond = logits[0::2, :]
|
| 335 |
logits_uncond = logits[1::2, :]
|
| 336 |
|
|
|
|
| 646 |
# print('nmsl: ', ii, ind)
|
| 647 |
if ii % 4 == 0:
|
| 648 |
offset = ind[1] + 2
|
| 649 |
+
inputs_embeds[ind[0], offset: offset + image_embeds_input.shape[1], :] = image_embeds_input[ii // 4]
|
| 650 |
|
| 651 |
generated_tokens = torch.zeros((3 * input_ids.size(0), image_token_num_per_image), dtype=torch.int).cuda()
|
| 652 |
|
|
|
|
| 658 |
# torch.save(inputs_embeds, '/data/bxh_data/unify_model/hidden_states.pt')
|
| 659 |
|
| 660 |
logits = self.gen_head(hidden_states)
|
| 661 |
+
print('logits.shape', logits.shape) # [3, 1760, 16384])
|
| 662 |
+
print(labels.shape) # [3, 1760]
|
| 663 |
|
| 664 |
# logits_cond = logits[0::2, :]
|
| 665 |
# logits_uncond = logits[1::2, :]
|
|
|
|
| 682 |
shift_labels = labels[..., 1:].contiguous()
|
| 683 |
shift_labels = shift_labels.view(-1)
|
| 684 |
shift_labels = shift_labels.to(shift_logits.device)
|
| 685 |
+
print(shift_logits.shape, shift_labels.shape)
|
| 686 |
loss = loss_fct(shift_logits, shift_labels)
|
| 687 |
else:
|
| 688 |
loss = None
|