Update modeling_neollm.py
Commit diff view for modeling_neollm.py (changed: +139 lines, −55 lines).
|
@@ -37,7 +37,7 @@ from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
|
|
| 37 |
from transformers.processing_utils import Unpack
|
| 38 |
from transformers.utils import TransformersKwargs, logging
|
| 39 |
from transformers.utils.generic import check_model_inputs
|
| 40 |
-
from
|
| 41 |
|
| 42 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 43 |
|
|
@@ -325,8 +325,6 @@ class SeeDNorm(nn.Module):
|
|
| 325 |
# ==================== STACK MEMORY MODULE ====================
|
| 326 |
class StackMemory(nn.Module):
|
| 327 |
"""
|
| 328 |
-
Differentiable Hidden State Stack for modeling Chomsky hierarchy grammars.
|
| 329 |
-
|
| 330 |
From "Improving Formal Reasoning of Transformer with State Stack":
|
| 331 |
Implements a multi-head differentiable stack with soft push, pop, and no-op operations.
|
| 332 |
Each head maintains its own stack and mask, which are updated based on learned action
|
|
@@ -354,8 +352,8 @@ class StackMemory(nn.Module):
|
|
| 354 |
|
| 355 |
# Dimension reduction projections for efficiency
|
| 356 |
# Uses standard nn.Linear
|
| 357 |
-
self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=
|
| 358 |
-
self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=
|
| 359 |
|
| 360 |
# Action prediction: generates push/pop/no-op probabilities for each head
|
| 361 |
self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
|
|
@@ -365,6 +363,20 @@ class StackMemory(nn.Module):
|
|
| 365 |
|
| 366 |
# Residual weight for gating stack contribution
|
| 367 |
self.res_weight = nn.Parameter(torch.ones(1))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 368 |
|
| 369 |
def _vectorized_update(
|
| 370 |
self,
|
|
@@ -393,8 +405,10 @@ class StackMemory(nn.Module):
|
|
| 393 |
batch_size, seq_len = actions.shape[:2]
|
| 394 |
|
| 395 |
# Expand stack and mask along sequence dimension for parallel processing
|
| 396 |
-
|
| 397 |
-
|
|
|
|
|
|
|
| 398 |
|
| 399 |
# Generate pushed stack: new value at top, shift others down
|
| 400 |
push_stack = torch.cat([
|
|
@@ -476,33 +490,93 @@ class StackMemory(nn.Module):
|
|
| 476 |
new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
|
| 477 |
|
| 478 |
# Global reading via query-over-stack attention
|
| 479 |
-
|
| 480 |
-
# FIX: Project the raw stack content directly.
|
| 481 |
-
# Previously, masking before projection killed gradients for "empty" slots
|
| 482 |
-
# preventing them from ever becoming "full".
|
| 483 |
gate_scores = self.gate_proj(new_stack).squeeze(-1) # [batch, seq, heads, slots]
|
| 484 |
|
| 485 |
-
|
| 486 |
-
# Mask out invalid positions (add large negative value where mask is 0)
|
| 487 |
-
gate_scores = gate_scores + (1 - new_mask) * -1e9
|
| 488 |
-
|
| 489 |
-
# Softmax to get attention weights
|
| 490 |
-
gate_weights = F.softmax(gate_scores, dim=-1)
|
| 491 |
|
| 492 |
# Weighted sum over stack slots
|
| 493 |
-
# new_stack contains the features, gate_weights contains the validity/relevance
|
| 494 |
memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
|
| 495 |
memory_output = memory_output.view(batch_size, seq_len, -1)
|
| 496 |
-
|
| 497 |
-
# Project back to original dimension
|
| 498 |
memory_output = self.up_proj(memory_output)
|
| 499 |
|
| 500 |
-
#
|
| 501 |
output = memory_output * self.res_weight + hidden_states
|
| 502 |
|
| 503 |
-
#
|
|
|
|
|
|
|
|
|
|
| 504 |
return output, new_stack[:, -1], new_mask[:, -1]
|
| 505 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 506 |
# ==================== ROTARY EMBEDDING ====================
|
| 507 |
class NeoLLMRotaryEmbedding(nn.Module):
|
| 508 |
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
|
@@ -1119,8 +1193,8 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1119 |
output_hidden_states: Optional[bool] = None,
|
| 1120 |
output_attentions: Optional[bool] = None,
|
| 1121 |
return_dict: Optional[bool] = None,
|
| 1122 |
-
|
| 1123 |
-
|
| 1124 |
**kwargs: Unpack[TransformersKwargs],
|
| 1125 |
) -> BaseModelOutputWithPast:
|
| 1126 |
output_hidden_states = (
|
|
@@ -1152,6 +1226,7 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1152 |
)
|
| 1153 |
|
| 1154 |
hidden_states = inputs_embeds
|
|
|
|
| 1155 |
all_hidden_states = () if output_hidden_states else None
|
| 1156 |
all_attentions = () if output_attentions else None
|
| 1157 |
|
|
@@ -1161,9 +1236,17 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1161 |
# ResFormer with first-layer feature propagation
|
| 1162 |
self.first_layer_fan = None
|
| 1163 |
|
| 1164 |
-
# Initialize Stack states
|
| 1165 |
-
stack_state =
|
| 1166 |
-
stack_mask =
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1167 |
|
| 1168 |
for decoder_layer in self.layers:
|
| 1169 |
if output_hidden_states:
|
|
@@ -1186,6 +1269,9 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1186 |
all_attentions = all_attentions + (layer_outputs[1],)
|
| 1187 |
|
| 1188 |
if self.use_stack:
|
|
|
|
|
|
|
|
|
|
| 1189 |
stack_state = layer_outputs[2]
|
| 1190 |
stack_mask = layer_outputs[3]
|
| 1191 |
|
|
@@ -1199,18 +1285,13 @@ class NeoLLMModel(NeoLLMPreTrainedModel):
|
|
| 1199 |
|
| 1200 |
if output_hidden_states:
|
| 1201 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
| 1202 |
-
|
| 1203 |
-
# Construct the persistence tuple (Stack only)
|
| 1204 |
-
next_cache = None
|
| 1205 |
-
if self.use_stack:
|
| 1206 |
-
next_cache = (stack_state, stack_mask)
|
| 1207 |
|
| 1208 |
if not return_dict:
|
| 1209 |
-
return tuple(v for v in [hidden_states,
|
| 1210 |
|
| 1211 |
return BaseModelOutputWithPast(
|
| 1212 |
last_hidden_state=hidden_states,
|
| 1213 |
-
past_key_values=
|
| 1214 |
hidden_states=all_hidden_states,
|
| 1215 |
attentions=all_attentions,
|
| 1216 |
)
|
|
@@ -1268,29 +1349,34 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1268 |
def prepare_inputs_for_generation(
|
| 1269 |
self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
|
| 1270 |
):
|
| 1271 |
-
|
| 1272 |
-
|
| 1273 |
-
past_stack_state = None
|
| 1274 |
-
past_stack_mask = None
|
| 1275 |
-
|
| 1276 |
-
if past_key_values is not None:
|
| 1277 |
-
# We use the past_key_values as a container for our custom states
|
| 1278 |
-
if len(past_key_values) == 2:
|
| 1279 |
-
past_stack_state, past_stack_mask = past_key_values
|
| 1280 |
|
| 1281 |
-
#
|
| 1282 |
-
|
| 1283 |
-
|
| 1284 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1285 |
"input_ids": input_ids,
|
| 1286 |
-
"
|
| 1287 |
-
"past_stack_mask": past_stack_mask,
|
| 1288 |
"use_cache": kwargs.get("use_cache"),
|
| 1289 |
-
"position_ids":
|
| 1290 |
"attention_mask": attention_mask,
|
| 1291 |
"inputs_embeds": inputs_embeds,
|
| 1292 |
}
|
| 1293 |
-
return model_inputs
|
| 1294 |
|
| 1295 |
def forward(
|
| 1296 |
self,
|
|
@@ -1302,8 +1388,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1302 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1303 |
output_hidden_states: Optional[bool] = None,
|
| 1304 |
return_dict: Optional[bool] = None,
|
| 1305 |
-
|
| 1306 |
-
past_stack_mask: Optional[torch.Tensor] = None,
|
| 1307 |
**kwargs: Unpack[TransformersKwargs],
|
| 1308 |
) -> CausalLMOutputWithPast:
|
| 1309 |
outputs: BaseModelOutputWithPast = self.model(
|
|
@@ -1313,8 +1398,7 @@ class NeoLLMForCausalLM(NeoLLMPreTrainedModel, GenerationMixin):
|
|
| 1313 |
inputs_embeds=inputs_embeds,
|
| 1314 |
output_hidden_states=output_hidden_states,
|
| 1315 |
return_dict=return_dict,
|
| 1316 |
-
|
| 1317 |
-
past_stack_mask=past_stack_mask,
|
| 1318 |
**kwargs,
|
| 1319 |
)
|
| 1320 |
|
|
|
|
| 37 |
from transformers.processing_utils import Unpack
|
| 38 |
from transformers.utils import TransformersKwargs, logging
|
| 39 |
from transformers.utils.generic import check_model_inputs
|
| 40 |
+
from configuration_neollm import NeoLLMConfig
|
| 41 |
|
| 42 |
from transformers import AutoConfig, AutoModel, AutoModelForCausalLM
|
| 43 |
|
|
|
|
| 325 |
# ==================== STACK MEMORY MODULE ====================
|
| 326 |
class StackMemory(nn.Module):
|
| 327 |
"""
|
|
|
|
|
|
|
| 328 |
From "Improving Formal Reasoning of Transformer with State Stack":
|
| 329 |
Implements a multi-head differentiable stack with soft push, pop, and no-op operations.
|
| 330 |
Each head maintains its own stack and mask, which are updated based on learned action
|
|
|
|
| 352 |
|
| 353 |
# Dimension reduction projections for efficiency
|
| 354 |
# Uses standard nn.Linear
|
| 355 |
+
self.down_proj = nn.Linear(config.hidden_size, self.stack_d_model, bias=True)
|
| 356 |
+
self.up_proj = nn.Linear(self.stack_d_model, config.hidden_size, bias=True)
|
| 357 |
|
| 358 |
# Action prediction: generates push/pop/no-op probabilities for each head
|
| 359 |
self.action_head = nn.Linear(self.stack_d_model, 3 * self.num_stack_heads, bias=True)
|
|
|
|
| 363 |
|
| 364 |
# Residual weight for gating stack contribution
|
| 365 |
self.res_weight = nn.Parameter(torch.ones(1))
|
| 366 |
+
|
| 367 |
+
# Cache for autoregressive generation (matches OLMo reference)
|
| 368 |
+
self.cache_size = getattr(config, "cache_size", 2048)
|
| 369 |
+
# Initialization fix: Register buffers for cache
|
| 370 |
+
# Default to batch_size=1 if forward_bs is not in config (standard inference)
|
| 371 |
+
forward_bs = getattr(config, 'forward_bs', 1)
|
| 372 |
+
self.register_buffer("k_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, self.head_dim))
|
| 373 |
+
self.register_buffer("action_cache", torch.zeros(forward_bs, self.cache_size, self.num_stack_heads, 3))
|
| 374 |
+
|
| 375 |
+
self.cache_position = 0
|
| 376 |
+
self.enable_cache = False
|
| 377 |
+
|
| 378 |
+
def reset_cache(self):
    """Rewind the generation cache to empty.

    Only the write cursor is reset; the underlying `k_cache` and
    `action_cache` buffers keep their storage and are simply overwritten
    in place by subsequent `_update_cache` calls.
    """
    self.cache_position = 0
|
| 380 |
|
| 381 |
def _vectorized_update(
|
| 382 |
self,
|
|
|
|
| 405 |
batch_size, seq_len = actions.shape[:2]
|
| 406 |
|
| 407 |
# Expand stack and mask along sequence dimension for parallel processing
|
| 408 |
+
# Only expand if checking against initial state dimensions (4D)
|
| 409 |
+
if stack.dim() == 4:
|
| 410 |
+
stack = stack.unsqueeze(1).expand(-1, seq_len, -1, -1, -1)
|
| 411 |
+
mask = mask.unsqueeze(1).expand(-1, seq_len, -1, -1)
|
| 412 |
|
| 413 |
# Generate pushed stack: new value at top, shift others down
|
| 414 |
push_stack = torch.cat([
|
|
|
|
| 490 |
new_stack, new_mask = self._vectorized_update(stack, mask, actions, k_values)
|
| 491 |
|
| 492 |
# Global reading via query-over-stack attention
|
|
|
|
|
|
|
|
|
|
|
|
|
| 493 |
gate_scores = self.gate_proj(new_stack).squeeze(-1) # [batch, seq, heads, slots]
|
| 494 |
|
| 495 |
+
gate_weights = F.softmax(gate_scores + (1 - new_mask) * -1e9, dim=-1)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 496 |
|
| 497 |
# Weighted sum over stack slots
|
|
|
|
| 498 |
memory_output = (new_stack * gate_weights.unsqueeze(-1)).sum(dim=3)
|
| 499 |
memory_output = memory_output.view(batch_size, seq_len, -1)
|
| 500 |
+
|
|
|
|
| 501 |
memory_output = self.up_proj(memory_output)
|
| 502 |
|
| 503 |
+
# Residual Connection
|
| 504 |
output = memory_output * self.res_weight + hidden_states
|
| 505 |
|
| 506 |
+
# Update Cache Logic
|
| 507 |
+
if self.enable_cache:
|
| 508 |
+
self._update_cache(k_values.detach(), actions.detach())
|
| 509 |
+
|
| 510 |
return output, new_stack[:, -1], new_mask[:, -1]
|
| 511 |
|
| 512 |
+
def _update_cache(self, k_values: torch.Tensor, actions: torch.Tensor) -> None:
    """Append per-token push values and action distributions to the cache.

    Args:
        k_values: `[batch, seq_len, num_stack_heads, head_dim]` push values
            (matches the shape `k_cache` was registered with).
        actions: `[batch, seq_len, num_stack_heads, 3]` softmaxed
            push/pop/no-op probabilities (matches `action_cache`).

    The cache is a fixed-size buffer of length `cache_size`. When it would
    overflow it is rewound to the start, so history recorded before the
    rewind is lost.
    """
    seq_len = k_values.shape[1]
    # BUGFIX: the previous implementation reset the cache on overflow and
    # returned without writing the incoming chunk, so the very token that
    # triggered the overflow vanished from the replayed history used by
    # `step`. Rewind first, then store the chunk whenever it fits at all.
    if self.cache_position + seq_len > self.cache_size:
        self.reset_cache()
    if seq_len <= self.cache_size:
        self.k_cache[:, self.cache_position:self.cache_position + seq_len] = k_values
        self.action_cache[:, self.cache_position:self.cache_position + seq_len] = actions
        self.cache_position += seq_len
|
| 521 |
+
|
| 522 |
+
def step(self, hidden_state: torch.Tensor, stack: torch.Tensor, mask: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """Single-token stack update for autoregressive decoding.

    Replays the vectorized stack recurrence over the cached history plus
    the current token, reads the resulting stack via gated attention, and
    appends the current token to the cache.

    NOTE(review): the cache buffers were registered with batch size
    `forward_bs`; this assumes the live batch matches it — confirm for
    batched generation.
    """
    # Without the cache there is no history to replay: treat the token as
    # a length-1 sequence and reuse the full forward pass.
    if not self.enable_cache:
        return self.forward(hidden_state.unsqueeze(1), stack, mask)

    bsz = hidden_state.shape[0]

    # Per-token features, action distribution, and push values.
    feats = self.down_proj(hidden_state)
    logits = self.action_head(feats) / math.sqrt(self.head_dim)
    step_actions = F.softmax(logits.view(bsz, 1, self.num_stack_heads, 3), dim=-1)
    step_k = feats.view(bsz, 1, self.num_stack_heads, self.head_dim)

    # Prepend cached history (if any) so the recurrence sees every
    # push/action observed so far.
    # k_values: [batch, total_len, heads, dim]; actions: [batch, total_len, heads, 3]
    if self.cache_position > 0:
        k_values = torch.cat([self.k_cache[:, :self.cache_position], step_k], dim=1)
        actions = torch.cat([self.action_cache[:, :self.cache_position], step_actions], dim=1)
    else:
        k_values, actions = step_k, step_actions

    stack_seq, mask_seq = self._vectorized_update(
        stack,  # initial stack [batch, heads, slots, dim]
        mask,
        actions,
        k_values,
    )

    # Only the final timestep matters for the current token.
    top_stack = stack_seq[:, -1]
    top_mask = mask_seq[:, -1]

    # Gated attention read over stack slots; empty slots get -1e9 scores.
    scores = self.gate_proj(top_stack).squeeze(-1)
    weights = F.softmax(scores + (1 - top_mask) * -1e9, dim=-1)

    read = (top_stack * weights.unsqueeze(-1)).sum(dim=2)
    read = read.view(bsz, -1)
    read = self.up_proj(read)

    self._update_cache(step_k, step_actions)

    return (
        read * self.res_weight + hidden_state,
        top_stack,
        top_mask,
    )
|
| 580 |
# ==================== ROTARY EMBEDDING ====================
|
| 581 |
class NeoLLMRotaryEmbedding(nn.Module):
|
| 582 |
inv_freq: torch.Tensor # fix linting for `register_buffer`
|
|
|
|
| 1193 |
output_hidden_states: Optional[bool] = None,
|
| 1194 |
output_attentions: Optional[bool] = None,
|
| 1195 |
return_dict: Optional[bool] = None,
|
| 1196 |
+
past_key_values: Optional[List[torch.FloatTensor]] = None,
|
| 1197 |
+
use_cache: Optional[bool] = None,
|
| 1198 |
**kwargs: Unpack[TransformersKwargs],
|
| 1199 |
) -> BaseModelOutputWithPast:
|
| 1200 |
output_hidden_states = (
|
|
|
|
| 1226 |
)
|
| 1227 |
|
| 1228 |
hidden_states = inputs_embeds
|
| 1229 |
+
next_decoder_cache = None
|
| 1230 |
all_hidden_states = () if output_hidden_states else None
|
| 1231 |
all_attentions = () if output_attentions else None
|
| 1232 |
|
|
|
|
| 1236 |
# ResFormer with first-layer feature propagation
|
| 1237 |
self.first_layer_fan = None
|
| 1238 |
|
| 1239 |
+
# Initialize Stack states (always None at start of forward, rebuilt via cache step or vertical flow)
|
| 1240 |
+
stack_state = None
|
| 1241 |
+
stack_mask = None
|
| 1242 |
+
|
| 1243 |
+
# Propagate use_cache and reset if starting a new sequence
|
| 1244 |
+
if self.use_stack:
|
| 1245 |
+
for layer in self.layers:
|
| 1246 |
+
if hasattr(layer, 'stack_memory'):
|
| 1247 |
+
layer.stack_memory.enable_cache = use_cache if use_cache is not None else False
|
| 1248 |
+
if past_key_values is None:
|
| 1249 |
+
layer.stack_memory.reset_cache()
|
| 1250 |
|
| 1251 |
for decoder_layer in self.layers:
|
| 1252 |
if output_hidden_states:
|
|
|
|
| 1269 |
all_attentions = all_attentions + (layer_outputs[1],)
|
| 1270 |
|
| 1271 |
if self.use_stack:
|
| 1272 |
+
# Vertical memory logic:
|
| 1273 |
+
# The layer returns updated stack for the next layer to use (Vertical passing)
|
| 1274 |
+
# But we do NOT persist it temporally here. The Module's internal cache handles temporal.
|
| 1275 |
stack_state = layer_outputs[2]
|
| 1276 |
stack_mask = layer_outputs[3]
|
| 1277 |
|
|
|
|
| 1285 |
|
| 1286 |
if output_hidden_states:
|
| 1287 |
all_hidden_states = all_hidden_states + (hidden_states,)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1288 |
|
| 1289 |
if not return_dict:
|
| 1290 |
+
return tuple(v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_attentions] if v is not None)
|
| 1291 |
|
| 1292 |
return BaseModelOutputWithPast(
|
| 1293 |
last_hidden_state=hidden_states,
|
| 1294 |
+
past_key_values=next_decoder_cache,
|
| 1295 |
hidden_states=all_hidden_states,
|
| 1296 |
attentions=all_attentions,
|
| 1297 |
)
|
|
|
|
| 1349 |
def prepare_inputs_for_generation(
    self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
):
    """Assemble the kwargs dict consumed by `forward` during generation.

    Trims `input_ids` down to the tokens not already covered by the cache
    and derives `position_ids` from the attention mask when the caller did
    not supply them.

    NOTE(review): `past_key_values[0][0].shape[2]` assumes the standard HF
    layer-KV layout `[batch, heads, past_len, head_dim]`; if the model
    instead stores `(stack_state, stack_mask)` here, dim 2 would be the
    stack-slot count, not the past length — confirm against the model's
    actual cache format.
    """
    if past_key_values:
        n_past = past_key_values[0][0].shape[2]
        # Keep only unseen tokens; otherwise fall back to the last token
        # (standard HF single-step decoding behaviour).
        if input_ids.shape[1] > n_past:
            trim = n_past
        else:
            trim = input_ids.shape[1] - 1
        input_ids = input_ids[:, trim:]

    position_ids = kwargs.get("position_ids", None)
    if attention_mask is not None and position_ids is None:
        # Build position ids on the fly so left-padded batches count
        # positions from the first real token.
        position_ids = attention_mask.long().cumsum(-1) - 1
        position_ids.masked_fill_(attention_mask == 0, 1)
        if past_key_values:
            position_ids = position_ids[:, -input_ids.shape[1]:]

    return {
        "input_ids": input_ids,
        "past_key_values": past_key_values,
        "use_cache": kwargs.get("use_cache"),
        "position_ids": position_ids,
        "attention_mask": attention_mask,
        "inputs_embeds": inputs_embeds,
    }
|
|
|
|
| 1380 |
|
| 1381 |
def forward(
|
| 1382 |
self,
|
|
|
|
| 1388 |
logits_to_keep: Union[int, torch.Tensor] = 0,
|
| 1389 |
output_hidden_states: Optional[bool] = None,
|
| 1390 |
return_dict: Optional[bool] = None,
|
| 1391 |
+
|
|
|
|
| 1392 |
**kwargs: Unpack[TransformersKwargs],
|
| 1393 |
) -> CausalLMOutputWithPast:
|
| 1394 |
outputs: BaseModelOutputWithPast = self.model(
|
|
|
|
| 1398 |
inputs_embeds=inputs_embeds,
|
| 1399 |
output_hidden_states=output_hidden_states,
|
| 1400 |
return_dict=return_dict,
|
| 1401 |
+
|
|
|
|
| 1402 |
**kwargs,
|
| 1403 |
)
|
| 1404 |
|