Update starvector_arch.py
starvector_arch.py (CHANGED, +5 -6)

@@ -159,23 +159,22 @@ class StarVectorForCausalLM(PreTrainedModel):
         if hasattr(self.model, 'svg_transformer') and hasattr(self.model.svg_transformer, 'gradient_checkpointing_enable'):
             self.model.svg_transformer.gradient_checkpointing_enable()
 
-    def forward(self,
+    def forward(self, vision_embeds, input_ids, num_generations, num_logits_to_keep) -> Union[Tuple, CausalLMOutputWithCrossAttentions]:
         r"""
         Wrapper for the forward pass of the model.
         """
-        device =
+        device = vision_embeds.device
 
         completion_embeds = self.model._get_embeddings(input_ids)
-
-        attention_mask = torch.ones_like(
+        vision_embeds = torch.cat([vision_embeds.repeat(num_generations, 1, 1), completion_embeds], dim=1)
+        attention_mask = torch.ones_like(vision_embeds[:, :, 0]).to(device)
 
         transformer_outputs = self.model.svg_transformer.transformer.transformer(
-            inputs_embeds=
+            inputs_embeds=vision_embeds,
             attention_mask=attention_mask,
         )
         hidden_states = transformer_outputs[0]
 
-        # If GRPO requested only the last tokens, slice accordingly.
         if num_logits_to_keep > 0:
             lm_logits = self.model.svg_transformer.transformer.lm_head(hidden_states[:, -num_logits_to_keep:, :])
         else:
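The new forward wires StarVector's image-conditioned generation up for GRPO-style training: it takes precomputed vision embeddings plus the sampled completions, tiles the image embeddings once per generation, prepends them to the completion embeddings, and only computes logits for the trailing positions. To make the tensor shapes concrete, here is a minimal, self-contained sketch of that plumbing; the sizes and the stand-in tensors are illustrative assumptions, not values from the commit.

import torch

# Illustrative sizes only; real values come from the image encoder and
# the tokenized SVG completions.
batch, vis_len, seq_len, hidden = 2, 4, 6, 8
num_generations, num_logits_to_keep = 3, 6

vision_embeds = torch.randn(batch, vis_len, hidden)
# GRPO samples num_generations completions per image, so the completion
# batch is num_generations times larger than the image batch.
completion_embeds = torch.randn(batch * num_generations, seq_len, hidden)

# Tile the image embeddings across the generations, then prepend them to
# the completion embeddings along the sequence axis (dim=1).
inputs_embeds = torch.cat(
    [vision_embeds.repeat(num_generations, 1, 1), completion_embeds], dim=1
)
print(inputs_embeds.shape)   # torch.Size([6, 10, 8])

# One mask entry per position: indexing [:, :, 0] drops the hidden dim.
attention_mask = torch.ones_like(inputs_embeds[:, :, 0])
print(attention_mask.shape)  # torch.Size([6, 10])

# After the transformer, only the trailing completion positions need
# logits for the GRPO loss; slicing before lm_head avoids materializing
# vocabulary-sized logits for the image prefix.
hidden_states = torch.randn_like(inputs_embeds)  # stand-in for transformer output
kept = hidden_states[:, -num_logits_to_keep:, :]
print(kept.shape)            # torch.Size([6, 6, 8])

Two details worth noting: the all-ones mask assumes every prepended image position and every completion position is a real token (padding is not handled here), and .repeat(num_generations, 1, 1) tiles the batch as [img0, img1, img0, img1, ...], so the rows of completion_embeds must be ordered the same way for the concatenation to pair each completion with its image.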