zR committed
Commit · 62f99e7
1 Parent(s): a22a67d

- modeling_glm.py  +12 −20

modeling_glm.py
CHANGED
@@ -31,7 +31,7 @@ from .configuration_glm import GlmConfig
 
 logger = logging.get_logger(__name__)
 
-_CHECKPOINT_FOR_DOC = "THUDM/glm-edge-5b"
+_CHECKPOINT_FOR_DOC = "THUDM/glm-edge-v-5b"
 _CONFIG_FOR_DOC = "GlmConfig"
 
 
@@ -763,32 +763,24 @@ class GlmModel(GlmPreTrainedModel):
         assert input_ids is not None and inputs_embeds is None, f"{input_ids} {inputs_embeds}"
         inputs_embeds = self.embed_tokens(input_ids)
         new_input_embeds = []
-
-
-
-
+        boi_token_flags = [True if self.config.boi_token_id in input_id.tolist() else False for input_id in input_ids]
+        if is_empty(images):
+            images = torch.zeros([1, 3, 672, 672]).to(input_ids.device)
+        images_features = self.vision(images).to(device=inputs_embeds.device, dtype=inputs_embeds.dtype)
         image_count = 0
         for i in range(len(input_ids)):
             input_id = input_ids[i].tolist()
-            if
+            if boi_token_flags[i]:
                 boi_token_pos = input_id.index(self.config.boi_token_id)
                 assert boi_token_pos >= 0, "begin_of_image not found!"
                 num_image_padding_tokens = input_id.count(self.config.boi_token_id)
-                assert (
-                    num_image_padding_tokens == images_features[image_count].shape[0]
-                ), f"Wrong image padding token number: {num_image_padding_tokens}"
-                new_input_embeds.append(
-                    torch.cat(
-                        (
-                            inputs_embeds[i, :boi_token_pos],
-                            images_features[image_count].to(inputs_embeds.device),
-                            inputs_embeds[i, boi_token_pos + num_image_padding_tokens :],
-                        )
-                    )
-                )
+                assert num_image_padding_tokens == images_features[image_count].shape[0], f"Wrong image padding token number: {num_image_padding_tokens}"
+                new_input_embeds.append(torch.cat(
+                    (inputs_embeds[i, :boi_token_pos], images_features[image_count],
+                     inputs_embeds[i, boi_token_pos + num_image_padding_tokens:])))
                 image_count += 1
             else:
-                new_input_embeds.append(inputs_embeds[i])
+                new_input_embeds.append(inputs_embeds[i] + (0 * images_features[0].sum()))
         inputs_embeds = torch.stack(new_input_embeds, dim=0)
 
         if self.gradient_checkpointing and self.training and use_cache:
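Note on the hunk above: the rewrite makes the vision tower run unconditionally. When `is_empty(images)` it substitutes a single 672×672 zero image, then splices each image's feature sequence over the run of `boi_token_id` placeholder tokens in that row; text-only rows instead get `0 * images_features[0].sum()` added, presumably so the vision parameters always participate in the autograd graph (for example, to avoid unused-parameter errors under distributed training). A minimal standalone sketch of the splice logic follows; `boi_token_id`, the token layout, and the feature count are made-up stand-ins, not the model's real configuration.

import torch

# Stand-in sizes; not the model's real configuration.
boi_token_id = 9        # hypothetical image-placeholder token id
num_image_tokens = 3    # vision features emitted per image
hidden = 4

input_ids = torch.tensor([[1, 2, 9, 9, 9, 5],    # row containing one image
                          [1, 2, 3, 4, 5, 6]])   # text-only row
inputs_embeds = torch.randn(2, 6, hidden)
images_features = torch.randn(1, num_image_tokens, hidden)  # one image in the batch

new_input_embeds = []
image_count = 0
for i in range(len(input_ids)):
    ids = input_ids[i].tolist()
    if boi_token_id in ids:
        pos = ids.index(boi_token_id)
        n = ids.count(boi_token_id)
        assert n == images_features[image_count].shape[0]
        # Overwrite the n placeholder positions with the image features.
        new_input_embeds.append(torch.cat((inputs_embeds[i, :pos],
                                           images_features[image_count],
                                           inputs_embeds[i, pos + n:])))
        image_count += 1
    else:
        # Text-only row: add a zero that still references the vision
        # output, mirroring the 0 * images_features[0].sum() trick above.
        new_input_embeds.append(inputs_embeds[i] + 0 * images_features[0].sum())
inputs_embeds = torch.stack(new_input_embeds, dim=0)
print(inputs_embeds.shape)  # torch.Size([2, 6, 4])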
@@ -1127,7 +1119,7 @@ class GlmForCausalLM(GlmPreTrainedModel, GenerationMixin):
     def forward(
         self,
         input_ids: torch.LongTensor = None,
-        pixel_values: torch.Tensor =
+        pixel_values: torch.Tensor = torch.zeros([1, 1, 1, 3, 672, 672]),
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Union[Cache, List[torch.FloatTensor]]] = None,
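Note on the hunk above: `pixel_values` now defaults to an all-zero sentinel of shape [1, 1, 1, 3, 672, 672], so text-only calls still pass a correctly shaped dummy through the image path and hit the `is_empty(images)` fallback from the earlier hunk. A sketch of how such a sentinel check could behave follows; this `is_empty` is a hypothetical stand-in, not necessarily the repo's actual helper.

from typing import Optional
import torch

def is_empty(images: Optional[torch.Tensor]) -> bool:
    # Hypothetical stand-in for the repo's is_empty helper: treat a
    # missing or all-zero batch as "no real images were provided".
    # The actual implementation in modeling_glm.py may differ.
    return images is None or images.numel() == 0 or bool((images == 0).all())

# The new zero-tensor default is classified as empty...
print(is_empty(torch.zeros([1, 1, 1, 3, 672, 672])))  # True
# ...while real pixel data is not (almost surely, for random inputs).
print(is_empty(torch.rand([1, 3, 672, 672])))  # False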