yxing-bj committed on
Commit
9d052a1
·
1 Parent(s): 91940c4

refactor code on modeling_iquestloopcoder

Files changed (1):
  1. modeling_iquestloopcoder.py +735 -1043
modeling_iquestloopcoder.py CHANGED
@@ -25,35 +25,114 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
  """

- import math
- from typing import Any, List, Optional, Tuple, Union

  import torch
- import torch.nn.functional as F
- import torch.utils.checkpoint
  from torch import nn

  from transformers.activations import ACT2FN
- from transformers.cache_utils import Cache, DynamicCache, StaticCache
- from transformers.modeling_attn_mask_utils import AttentionMaskConverter
  from transformers.modeling_outputs import (
      BaseModelOutputWithPast,
      CausalLMOutputWithPast,
  )
- from transformers.modeling_utils import PreTrainedModel
- from transformers.generation.utils import GenerationMixin
- from transformers.utils import (
-     add_start_docstrings,
-     add_start_docstrings_to_model_forward,
-     logging,
-     replace_return_docstrings,
- )
-
  from .configuration_iquestloopcoder import IQuestLoopCoderConfig

- logger = logging.get_logger(__name__)

- _CONFIG_FOR_DOC = "IQuestLoopCoderConfig"

  class IQuestLoopCoderCache(Cache):
@@ -63,18 +142,19 @@ class IQuestLoopCoderCache(Cache):
      - local_key_cache/local_value_cache: Stores KV from Loop 2+ (local window, only window_size tokens)
      """

-     def __init__(self, window_size: int, num_layers: int):
          # We intentionally don't call super().__init__ because the parent assumes static cache sizes.
          self.window_size = window_size
          self.num_layers = num_layers

-         # Shared cache: stores Loop 1 KV (global context)
-         self.shared_key_cache: List[Optional[torch.Tensor]] = [None] * num_layers
-         self.shared_value_cache: List[Optional[torch.Tensor]] = [None] * num_layers

          # Local cache: stores Loop 2+ KV (sliding window, only window_size tokens)
-         self.local_key_cache: List[Optional[torch.Tensor]] = [None] * num_layers
-         self.local_value_cache: List[Optional[torch.Tensor]] = [None] * num_layers

          self.layers: List[Any] = []  # attribute expected by HF Cache utilities
          self._seen_tokens = 0
@@ -87,6 +167,9 @@ class IQuestLoopCoderCache(Cache):
          cache_kwargs: Optional[dict] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Update shared cache (Loop 1 KV)."""
          if layer_idx < 0 or layer_idx >= self.num_layers:
              raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")

@@ -105,7 +188,8 @@ class IQuestLoopCoderCache(Cache):
              raise ValueError(
                  "Cached and incoming key/value tensors must match on batch, head, and head_dim dimensions."
              )
-         assert cached_value is not None
          self.shared_key_cache[layer_idx] = torch.cat([cached_key, key_states], dim=2)
          self.shared_value_cache[layer_idx] = torch.cat([cached_value, value_states], dim=2)

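update_shared grows the cache by a plain concatenation along the sequence axis. A minimal standalone sketch of that step (the shapes are illustrative, not taken from the config):

import torch

# KV tensors are laid out as [batch, kv_heads, seq_len, head_dim]; appending a
# decode step grows dim=2 by one token.
cache_k = torch.randn(1, 4, 10, 8)
new_k = torch.randn(1, 4, 1, 8)
cache_k = torch.cat([cache_k, new_k], dim=2)
assert cache_k.shape[2] == 11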
@@ -126,19 +210,48 @@ class IQuestLoopCoderCache(Cache):
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Update local cache (Loop 2+ KV) with sliding window management.

-         If the cache is full (window_size tokens), remove the oldest token and add the new one.
          """
          if layer_idx < 0 or layer_idx >= self.num_layers:
              raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")

-         cached_key = self.local_key_cache[layer_idx]
-         cached_value = self.local_value_cache[layer_idx]

          if cached_key is None:
-             # First token in local cache
-             self.local_key_cache[layer_idx] = key_states
-             self.local_value_cache[layer_idx] = value_states
          else:
              if (
                  key_states.shape[0] != cached_key.shape[0]
                  or key_states.shape[1] != cached_key.shape[1]
@@ -148,35 +261,62 @@ class IQuestLoopCoderCache(Cache):
                  "Cached and incoming key/value tensors must match on batch, head, and head_dim dimensions."
              )
          assert cached_value is not None

-             # Check if we need to remove the oldest token
-             current_len = cached_key.shape[2]
-             if current_len >= self.window_size:
-                 # Remove the first token (oldest) and add the new one
-                 self.local_key_cache[layer_idx] = torch.cat([cached_key[:, :, 1:, :], key_states], dim=2)
-                 self.local_value_cache[layer_idx] = torch.cat([cached_value[:, :, 1:, :], value_states], dim=2)
              else:
-                 # Just append
-                 self.local_key_cache[layer_idx] = torch.cat([cached_key, key_states], dim=2)
-                 self.local_value_cache[layer_idx] = torch.cat([cached_value, value_states], dim=2)

-         result_key = self.local_key_cache[layer_idx]
-         result_value = self.local_value_cache[layer_idx]
          assert result_key is not None and result_value is not None

          return result_key, result_value

-     def get_shared(self, layer_idx: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
-         """Get shared cache for a layer."""
          if layer_idx < 0 or layer_idx >= self.num_layers:
-             return None, None
          return self.shared_key_cache[layer_idx], self.shared_value_cache[layer_idx]

-     def get_local(self, layer_idx: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
          """Get local cache for a layer."""
          if layer_idx < 0 or layer_idx >= self.num_layers:
-             return None, None
-         return self.local_key_cache[layer_idx], self.local_value_cache[layer_idx]

      def update(
          self,
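The eviction policy deleted above can be exercised in isolation. This sketch mirrors update_local's one-in, one-out slicing (shapes are illustrative):

import torch

def sliding_window_append(cached_k, new_k, window_size):
    # Mirrors update_local: once the window is full, drop the oldest position
    # (index 0 on dim=2) before appending the incoming token.
    if cached_k is None:
        return new_k
    if cached_k.shape[2] >= window_size:
        cached_k = cached_k[:, :, 1:, :]
    return torch.cat([cached_k, new_k], dim=2)

k = None
for _ in range(70):
    k = sliding_window_append(k, torch.randn(1, 4, 1, 8), window_size=64)
assert k.shape[2] == 64  # never exceeds the window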
@@ -186,18 +326,23 @@ class IQuestLoopCoderCache(Cache):
          cache_kwargs: Optional[dict] = None,
      ) -> Tuple[torch.Tensor, torch.Tensor]:
          """Default update method (for compatibility, updates shared cache)."""
-         return self.update_shared(key_states, value_states, layer_idx, cache_kwargs)
-
      def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
          """Get sequence length from shared cache."""
          if layer_idx is None:
              layer_idx = 0
-         if layer_idx < 0 or layer_idx >= len(self.shared_key_cache):
              return 0
-         cached = self.shared_key_cache[layer_idx]
-         if cached is None:
              return 0
-         return cached.shape[2]

      def get_max_length(self) -> Optional[int]:
          return None
@@ -208,17 +353,26 @@ class IQuestLoopCoderCache(Cache):
          return self.get_seq_length(layer_idx)

      def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
-         """Reorder cache for beam search."""
          for layer_idx in range(self.num_layers):
              if self.shared_key_cache[layer_idx] is not None:
                  device = self.shared_key_cache[layer_idx].device
                  self.shared_key_cache[layer_idx] = self.shared_key_cache[layer_idx].index_select(0, beam_idx.to(device))
                  self.shared_value_cache[layer_idx] = self.shared_value_cache[layer_idx].index_select(0, beam_idx.to(device))
-
-             if self.local_key_cache[layer_idx] is not None:
-                 device = self.local_key_cache[layer_idx].device
-                 self.local_key_cache[layer_idx] = self.local_key_cache[layer_idx].index_select(0, beam_idx.to(device))
-                 self.local_value_cache[layer_idx] = self.local_value_cache[layer_idx].index_select(0, beam_idx.to(device))

      @property
      def is_compileable(self) -> bool:
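reorder_cache relies on index_select over the batch dimension, which beam search permutes. A small standalone demonstration (shapes are illustrative):

import torch

# index_select(0, beam_idx) applies the same beam permutation to every
# cached KV tensor; dim 0 holds the beams.
k = torch.arange(4).view(4, 1, 1, 1).expand(4, 2, 3, 8).float()
beam_idx = torch.tensor([2, 2, 0, 1])
k_reordered = k.index_select(0, beam_idx)
assert k_reordered[0, 0, 0, 0].item() == 2.0  # slot 0 now copies beam 2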
@@ -229,96 +383,39 @@ class IQuestLoopCoderCache(Cache):
          logger.debug("Clearing IQuestLoopCoderCache")
          self.shared_key_cache = [None] * self.num_layers
          self.shared_value_cache = [None] * self.num_layers
-         self.local_key_cache = [None] * self.num_layers
-         self.local_value_cache = [None] * self.num_layers
          self._seen_tokens = 0


- class IQuestLoopCoderRMSNorm(nn.Module):
-     """RMS Normalization layer."""
-
-     def __init__(self, hidden_size, eps=1e-6):
-         super().__init__()
-         self.weight = nn.Parameter(torch.ones(hidden_size))
-         self.variance_epsilon = eps
-
-     def forward(self, hidden_states):
-         input_dtype = hidden_states.dtype
-         hidden_states = hidden_states.to(torch.float32)
-         variance = hidden_states.pow(2).mean(-1, keepdim=True)
-         hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
-         return self.weight * hidden_states.to(input_dtype)
-
-
- class IQuestLoopCoderRotaryEmbedding(nn.Module):
-     """Rotary Position Embedding (RoPE)."""
-
-     def __init__(self, dim, max_position_embeddings=8192, base=500000.0, device=None, scaling_factor=1.0):
-         super().__init__()
-         self.scaling_factor = scaling_factor
-         self.dim = dim
-         self.max_position_embeddings = max_position_embeddings
-         self.base = base
-         inv_freq = 1.0 / (self.base ** (torch.arange(0, self.dim, 2, dtype=torch.int64).float().to(device) / self.dim))
-         self.register_buffer("inv_freq", inv_freq, persistent=False)
-         self.max_seq_len_cached = max_position_embeddings
-
-     @torch.no_grad()
-     def forward(self, x, position_ids):
-         # x: [batch_size, num_heads, seq_len, head_dim]
-         inv_freq_expanded = self.inv_freq[None, :, None].float().expand(position_ids.shape[0], -1, 1)
-         position_ids_expanded = position_ids[:, None, :].float()
-
-         device_type = x.device.type
-         with torch.autocast(device_type=device_type, enabled=False):
-             freqs = (inv_freq_expanded.float() @ position_ids_expanded.float()).transpose(1, 2)
-             emb = torch.cat((freqs, freqs), dim=-1)
-             cos = emb.cos()
-             sin = emb.sin()
-         return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
-
-
- def rotate_half(x):
-     """Rotates half the hidden dims of the input."""
-     x1 = x[..., : x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2 :]
-     return torch.cat((-x2, x1), dim=-1)
-
-
- def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
-     """Applies Rotary Position Embedding to the query and key tensors."""
-     cos = cos.unsqueeze(unsqueeze_dim)
-     sin = sin.unsqueeze(unsqueeze_dim)
-     q_embed = (q * cos) + (rotate_half(q) * sin)
-     k_embed = (k * cos) + (rotate_half(k) * sin)
-     return q_embed, k_embed
-
-
- def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
-     """Expand KV heads to match query heads for GQA."""
-     batch, num_key_value_heads, slen, head_dim = hidden_states.shape
-     if n_rep == 1:
-         return hidden_states
-     hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_key_value_heads, n_rep, slen, head_dim)
-     return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
-
-
- class IQuestLoopCoderMLP(nn.Module):
-     """MLP with SwiGLU activation."""
-
-     def __init__(self, config):
-         super().__init__()
-         self.config = config
-         self.hidden_size = config.hidden_size
-         self.intermediate_size = config.intermediate_size
-         self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-         self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=config.mlp_bias)
-         self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=config.mlp_bias)
-         self.act_fn = ACT2FN[config.hidden_act]
-
-     def forward(self, x):
-         return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
-

  class LoopGateProjection(nn.Module):
      """Gate projection for mixed attention in Loop 2+.
@@ -354,953 +451,554 @@ class LoopGateProjection(nn.Module):
          gate = torch.sigmoid(gate_logits)
          return gate.unsqueeze(-1)  # [batch, num_heads, seq_len, 1]

-
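LoopGateProjection produces a sigmoid gate per head and query position, and Loop 2+ blends a global and a local attention output with it. A toy sketch of that blend (shapes are made up for illustration):

import torch

# Per-head scalar gate in [0, 1]; broadcasting over head_dim lets one gate
# value blend the two attention outputs for a whole head.
attn_A = torch.randn(1, 32, 5, 64)   # global branch (Loop 1 KV)
attn_B = torch.randn(1, 32, 5, 64)   # local branch (sliding window KV)
gate = torch.sigmoid(torch.randn(1, 32, 5, 1))
mixed = gate * attn_A + (1 - gate) * attn_B
assert mixed.shape == attn_A.shape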
  class IQuestLoopCoderAttention(nn.Module):
-     """Multi-head attention with GQA support."""
-
-     def __init__(self, config: IQuestLoopCoderConfig, layer_idx: Optional[int] = None):
          super().__init__()
          self.config = config
          self.layer_idx = layer_idx
-
-         self.hidden_size = config.hidden_size
-         self.num_heads = config.num_attention_heads
-         self.head_dim = config.head_dim
-         self.num_key_value_heads = config.num_key_value_heads
-         self.num_key_value_groups = self.num_heads // self.num_key_value_heads
-         self.max_position_embeddings = config.max_position_embeddings
-         self.rope_theta = config.rope_theta
          self.attention_dropout = config.attention_dropout
-
-         self.q_proj = nn.Linear(self.hidden_size, self.num_heads * self.head_dim, bias=config.attention_bias)
-         self.k_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-         self.v_proj = nn.Linear(self.hidden_size, self.num_key_value_heads * self.head_dim, bias=config.attention_bias)
-         self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=config.attention_bias)
-
-         self.rotary_emb = IQuestLoopCoderRotaryEmbedding(
-             self.head_dim,
-             max_position_embeddings=self.max_position_embeddings,
-             base=self.rope_theta,
          )

      def forward(
          self,
          hidden_states: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
          past_key_value: Optional[Cache] = None,
-         output_attentions: bool = False,
-         use_cache: bool = False,
          cache_position: Optional[torch.LongTensor] = None,
-         **kwargs,
-     ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
-         bsz, q_len, _ = hidden_states.size()
-
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-         cos, sin = self.rotary_emb(value_states, position_ids)
-         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)

          if past_key_value is not None:
-             cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
-             key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
-
-         # Repeat KV for GQA
-         key_states = repeat_kv(key_states, self.num_key_value_groups)
-         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-         if attention_mask is not None:
-             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-             attn_weights = attn_weights + causal_mask

-         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-         attn_output = torch.matmul(attn_weights, value_states)

-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.reshape(bsz, q_len, -1)
          attn_output = self.o_proj(attn_output)

-         return attn_output, attn_weights if output_attentions else None, past_key_value
-
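The deleted forward is the standard eager attention path: GQA head expansion, scaled scores, additive mask, fp32 softmax cast back to the query dtype, then a weighted sum of values. A self-contained sketch (32 query heads, 8 KV heads, head_dim 64 — illustrative numbers only):

import math
import torch

def repeat_kv(hidden_states, n_rep):
    # Verbatim GQA expansion from the deleted helper above.
    batch, num_kv_heads, slen, head_dim = hidden_states.shape
    if n_rep == 1:
        return hidden_states
    hidden_states = hidden_states[:, :, None, :, :].expand(batch, num_kv_heads, n_rep, slen, head_dim)
    return hidden_states.reshape(batch, num_kv_heads * n_rep, slen, head_dim)

q = torch.randn(1, 32, 5, 64)
k = repeat_kv(torch.randn(1, 8, 9, 64), 4)   # 8 KV heads -> 32
v = repeat_kv(torch.randn(1, 8, 9, 64), 4)
scores = torch.matmul(q, k.transpose(2, 3)) / math.sqrt(64)
weights = torch.softmax(scores, dim=-1, dtype=torch.float32).to(q.dtype)
out = torch.matmul(weights, v)  # [1, 32, 5, 64]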
-     def forward_with_external_kv(
-         self,
-         hidden_states: torch.Tensor,
-         external_key: torch.Tensor,
-         external_value: torch.Tensor,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         sliding_window: Optional[int] = None,
-     ) -> torch.Tensor:
-         """Forward pass using external K, V (for Loop 2+ mixed attention).
-
-         Args:
-             hidden_states: Input for computing Q
-             external_key: Pre-computed K (already with RoPE applied)
-             external_value: Pre-computed V
-             attention_mask: Causal attention mask
-             position_ids: Position IDs
-             sliding_window: If set, apply sliding window attention
-
-         Returns:
-             Attention output [batch, seq_len, num_heads, head_dim]
-         """
-         bsz, q_len, _ = hidden_states.size()
-
-         # Compute Q from current hidden states
-         query_states = self.q_proj(hidden_states)
-         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-
-         # Apply RoPE to Q
-         cos, sin = self.rotary_emb(query_states, position_ids)
-         query_states = (query_states * cos.unsqueeze(1)) + (rotate_half(query_states) * sin.unsqueeze(1))
-
-         # Use external K, V (already have RoPE for K)
-         key_states = external_key
-         value_states = external_value
-
-         # Repeat KV for GQA
-         key_states = repeat_kv(key_states, self.num_key_value_groups)
-         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-         # Compute attention
-         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-         # Apply attention mask (causal)
-         if attention_mask is not None:
-             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-             attn_weights = attn_weights + causal_mask
-
-         # Apply sliding window mask if needed
-         if sliding_window is not None and q_len > sliding_window:
-             # Create sliding window mask
-             # For each position i, can only attend to [i-window+1, i]
-             seq_len = key_states.shape[2]
-             row_idx = torch.arange(q_len, device=query_states.device).unsqueeze(1)
-             col_idx = torch.arange(seq_len, device=query_states.device).unsqueeze(0)
-             window_mask = (col_idx > row_idx) | (col_idx < row_idx - sliding_window + 1)
-             window_mask = window_mask.unsqueeze(0).unsqueeze(0)  # [1, 1, q_len, seq_len]
-             attn_weights = attn_weights.masked_fill(window_mask, float('-inf'))
-
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-         attn_output = torch.matmul(attn_weights, value_states)
-
-         # Don't apply o_proj here - return raw attention output
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         return attn_output  # [batch, seq_len, num_heads, head_dim]
-
-     def get_qkv(
-         self,
-         hidden_states: torch.Tensor,
-         position_ids: Optional[torch.LongTensor] = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         """Get Q, K, V tensors with RoPE applied.
-
-         Returns:
-             query: [batch, num_heads, seq_len, head_dim]
-             key: [batch, num_kv_heads, seq_len, head_dim]
-             value: [batch, num_kv_heads, seq_len, head_dim]
-         """
-         bsz, q_len, _ = hidden_states.size()
-
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-         cos, sin = self.rotary_emb(value_states, position_ids)
-         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-         return query_states, key_states, value_states
-
-     def forward_decode_loop1(
-         self,
-         hidden_states: torch.Tensor,
-         past_shared_key: Optional[torch.Tensor],
-         past_shared_value: Optional[torch.Tensor],
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         cache_position: Optional[torch.LongTensor] = None,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         """Forward pass for Loop 1 in decode stage.
-
-         Args:
-             hidden_states: Current hidden states [batch, 1, hidden_size]
-             past_shared_key: Past shared keys from cache [batch, num_kv_heads, past_len, head_dim]
-             past_shared_value: Past shared values from cache [batch, num_kv_heads, past_len, head_dim]
-             attention_mask: Causal attention mask
-             position_ids: Position IDs
-             cache_position: Cache position
-
-         Returns:
-             output: Attention output [batch, 1, hidden_size]
-             k1: Current key [batch, num_kv_heads, 1, head_dim] (only current token)
-             v1: Current value [batch, num_kv_heads, 1, head_dim] (only current token)
          """
-         bsz, q_len, _ = hidden_states.size()
-
-         query_states = self.q_proj(hidden_states)
-         key_states = self.k_proj(hidden_states)
-         value_states = self.v_proj(hidden_states)
-
-         query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-         key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-         value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
-
-         cos, sin = self.rotary_emb(value_states, position_ids)
-         query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin)
-
-         # Store current token's k1, v1 for return (before concatenation)
-         k1_current = key_states  # [batch, num_kv_heads, 1, head_dim]
-         v1_current = value_states  # [batch, num_kv_heads, 1, head_dim]
-
-         # Concatenate with past shared KV cache for attention computation
-         if past_shared_key is not None and past_shared_value is not None:
-             key_states = torch.cat([past_shared_key, key_states], dim=2)
-             value_states = torch.cat([past_shared_value, value_states], dim=2)
-
-         # Repeat KV for GQA
-         key_states = repeat_kv(key_states, self.num_key_value_groups)
-         value_states = repeat_kv(value_states, self.num_key_value_groups)
-
-         attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim)
-
-         if attention_mask is not None:
-             causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
-             attn_weights = attn_weights + causal_mask
-
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
-         attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
-         attn_output = torch.matmul(attn_weights, value_states)
-
-         attn_output = attn_output.transpose(1, 2).contiguous()
-         attn_output = attn_output.reshape(bsz, q_len, -1)
-         attn_output = self.o_proj(attn_output)
-
-         return attn_output, k1_current, v1_current
-
-     def forward_decode_loop2(
-         self,
-         hidden_states: torch.Tensor,
-         k1: torch.Tensor,
-         v1: torch.Tensor,
-         past_shared_key: Optional[torch.Tensor],
-         past_shared_value: Optional[torch.Tensor],
-         past_local_key: Optional[torch.Tensor],
-         past_local_value: Optional[torch.Tensor],
-         gate_proj: LoopGateProjection,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         loop_window_size: int = 64,
-     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
-         """Forward pass for Loop 2 in decode stage with mixed attention.
-
-         Args:
-             hidden_states: Current hidden states [batch, 1, hidden_size]
-             k1: Key from Loop 1 (current token) [batch, num_kv_heads, 1, head_dim]
-             v1: Value from Loop 1 (current token) [batch, num_kv_heads, 1, head_dim]
-             past_shared_key: Past shared keys from cache [batch, num_kv_heads, past_len, head_dim]
-             past_shared_value: Past shared values from cache [batch, num_kv_heads, past_len, head_dim]
-             past_local_key: Past local keys from cache [batch, num_kv_heads, window_len, head_dim]
-             past_local_value: Past local values from cache [batch, num_kv_heads, window_len, head_dim]
-             gate_proj: Gate projection module
-             attention_mask: Causal attention mask
-             position_ids: Position IDs
-             loop_window_size: Window size for sliding window attention
-
-         Returns:
-             output: Attention output [batch, 1, hidden_size]
-             k2: Current key [batch, num_kv_heads, 1, head_dim]
-             v2: Current value [batch, num_kv_heads, 1, head_dim]
          """
-         bsz, q_len, _ = hidden_states.size()
-
-         # Get Q2, K2, V2 for current loop
-         q2, k2, v2 = self.get_qkv(hidden_states, position_ids)
-
-         # Compute gate: g = sigmoid(linear(Q2))
-         gate = gate_proj(q2)  # [batch, num_heads, 1, 1]
-
-         # For attention A: concatenate past shared KV with current k1, v1 (full global context)
-         if past_shared_key is not None and past_shared_value is not None:
-             k1_full = torch.cat([past_shared_key, k1], dim=2)
-             v1_full = torch.cat([past_shared_value, v1], dim=2)
-         else:
-             k1_full = k1
-             v1_full = v1
-
-         # For attention B: concatenate past local KV with current k2, v2 (sliding window)
-         if past_local_key is not None and past_local_value is not None:
-             k2_full = torch.cat([past_local_key, k2], dim=2)
-             v2_full = torch.cat([past_local_value, v2], dim=2)
-         else:
-             k2_full = k2
-             v2_full = v2
-
-         # Repeat KV for GQA
-         k1_expanded = repeat_kv(k1_full, self.num_key_value_groups)
-         v1_expanded = repeat_kv(v1_full, self.num_key_value_groups)
-         k2_expanded = repeat_kv(k2_full, self.num_key_value_groups)
-         v2_expanded = repeat_kv(v2_full, self.num_key_value_groups)
-
-         # Attention A: Q2 @ K1_full, V1_full (global, full sequence)
-         head_dim = q2.shape[-1]
-         attn_weights_A = torch.matmul(q2, k1_expanded.transpose(2, 3)) / math.sqrt(head_dim)
-         if attention_mask is not None:
-             causal_mask = attention_mask[:, :, :, : k1_expanded.shape[-2]]
-             attn_weights_A = attn_weights_A + causal_mask
-         attn_weights_A = nn.functional.softmax(attn_weights_A, dim=-1, dtype=torch.float32).to(q2.dtype)
-         attn_A = torch.matmul(attn_weights_A, v1_expanded)
-
-         # Attention B: Q2 @ K2_full, V2_full (local sliding window)
-         attn_weights_B = torch.matmul(q2, k2_expanded.transpose(2, 3)) / math.sqrt(head_dim)
-         if attention_mask is not None:
-             causal_mask = attention_mask[:, :, :, : k2_expanded.shape[-2]]
-             attn_weights_B = attn_weights_B + causal_mask
-
-         # Apply sliding window mask
-         q_len_attn = q2.shape[2]
-         k_len_attn = k2_expanded.shape[2]
-         if q_len_attn <= loop_window_size:
-             # If sequence fits in window, use standard attention
-             attn_weights_B = nn.functional.softmax(attn_weights_B, dim=-1, dtype=torch.float32).to(q2.dtype)
-         else:
-             # Apply sliding window mask
-             row_idx = torch.arange(q_len_attn, device=q2.device).unsqueeze(1)
-             col_idx = torch.arange(k_len_attn, device=q2.device).unsqueeze(0)
-             window_mask = (col_idx > row_idx) | (col_idx < row_idx - loop_window_size + 1)
-             window_mask = window_mask.unsqueeze(0).unsqueeze(0)
-             attn_weights_B = attn_weights_B.masked_fill(window_mask, float('-inf'))
-             attn_weights_B = nn.functional.softmax(attn_weights_B, dim=-1, dtype=torch.float32).to(q2.dtype)
-         attn_B = torch.matmul(attn_weights_B, v2_expanded)
-
-         # Mixed attention: gate * A + (1 - gate) * B
-         mixed_attn = gate * attn_A + (1 - gate) * attn_B
-
-         # Reshape and apply output projection
-         bsz, num_heads, seq_len, head_dim = mixed_attn.shape
-         mixed_attn = mixed_attn.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
-         attn_output = self.o_proj(mixed_attn)
-
-         return attn_output, k2, v2

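Both decode paths above build the same sliding-window mask: position i may attend only to [i - window + 1, i]. Built standalone (illustrative sizes):

import torch

q_len, k_len, window = 6, 6, 3
row = torch.arange(q_len).unsqueeze(1)
col = torch.arange(k_len).unsqueeze(0)
blocked = (col > row) | (col < row - window + 1)  # True = masked out
scores = torch.zeros(q_len, k_len).masked_fill(blocked, float('-inf'))
# row 5 can see columns 3, 4, 5 only
assert torch.isinf(scores[5, 2]) and scores[5, 3] == 0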
- class IQuestLoopCoderDecoderLayer(nn.Module):
-     """Transformer decoder layer."""
-
      def __init__(self, config: IQuestLoopCoderConfig, layer_idx: int):
          super().__init__()
          self.hidden_size = config.hidden_size
          self.self_attn = IQuestLoopCoderAttention(config=config, layer_idx=layer_idx)
          self.mlp = IQuestLoopCoderMLP(config)
          self.input_layernorm = IQuestLoopCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-         self.post_attention_layernorm = IQuestLoopCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
      def forward(
          self,
          hidden_states: torch.Tensor,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_value: Optional[Cache] = None,
-         output_attentions: Optional[bool] = False,
          use_cache: Optional[bool] = False,
          cache_position: Optional[torch.LongTensor] = None,
-         **kwargs,
-     ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]:
          residual = hidden_states
          hidden_states = self.input_layernorm(hidden_states)
-
-         hidden_states, self_attn_weights, present_key_value = self.self_attn(
              hidden_states=hidden_states,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_value=past_key_value,
-             output_attentions=output_attentions,
              use_cache=use_cache,
              cache_position=cache_position,
              **kwargs,
          )
-         hidden_states = residual + hidden_states

-         residual = hidden_states
-         hidden_states = self.post_attention_layernorm(hidden_states)
-         hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
-
-         outputs = (hidden_states,)
-         if output_attentions:
-             outputs += (self_attn_weights,)
-         if use_cache:
-             outputs += (present_key_value,)
-         return outputs
-
-     def forward_loop2_mixed(
-         self,
-         hidden_states: torch.Tensor,
-         k1: torch.Tensor,
-         v1: torch.Tensor,
-         gate_proj: LoopGateProjection,
-         attention_mask: Optional[torch.Tensor] = None,
-         position_ids: Optional[torch.LongTensor] = None,
-         loop_window_size: int = 64,
-     ) -> Tuple[torch.Tensor, float]:
-         """Forward pass for Loop 2+ with mixed attention.

-         Args:
-             hidden_states: Current hidden states
-             k1: Key from Loop 1 [batch, num_kv_heads, seq_len, head_dim]
-             v1: Value from Loop 1 [batch, num_kv_heads, seq_len, head_dim]
-             gate_proj: Gate projection module for this layer
-             attention_mask: Causal attention mask
-             position_ids: Position IDs
-             loop_window_size: Window size for sliding window attention
-
-         Returns:
-             output hidden states, gate mean value
-         """
-         residual = hidden_states
-         hidden_states_normed = self.input_layernorm(hidden_states)
-
-         # Get Q2, K2, V2 for current loop
-         q2, k2, v2 = self.self_attn.get_qkv(hidden_states_normed, position_ids)
-
-         # Compute gate: g = sigmoid(linear(Q2))
-         # q2: [batch, num_heads, seq_len, head_dim]
-         gate = gate_proj(q2)  # [batch, num_heads, seq_len, 1]
-         gate_mean = gate.detach().mean().item()
-
-         # Repeat K1, V1 for GQA
-         k1_expanded = repeat_kv(k1, self.self_attn.num_key_value_groups)
-         v1_expanded = repeat_kv(v1, self.self_attn.num_key_value_groups)
-         k2_expanded = repeat_kv(k2, self.self_attn.num_key_value_groups)
-         v2_expanded = repeat_kv(v2, self.self_attn.num_key_value_groups)
-
-         # Attention A: Q2 @ K1, V1 (global, full sequence)
-         attn_A = self._compute_attention(q2, k1_expanded, v1_expanded, attention_mask)
-
-         # Attention B: Q2 @ K2, V2 (local sliding window)
-         attn_B = self._compute_attention_with_window(q2, k2_expanded, v2_expanded, attention_mask, loop_window_size)
-
-         # Mixed attention: gate * A + (1 - gate) * B
-         # attn_A, attn_B: [batch, num_heads, seq_len, head_dim]
-         mixed_attn = gate * attn_A + (1 - gate) * attn_B
-
-         # Reshape and apply output projection
-         bsz, num_heads, seq_len, head_dim = mixed_attn.shape
-         mixed_attn = mixed_attn.transpose(1, 2).contiguous().reshape(bsz, seq_len, -1)
-         hidden_states = self.self_attn.o_proj(mixed_attn)
-
-         hidden_states = residual + hidden_states
-
-         # MLP
          residual = hidden_states
          hidden_states = self.post_attention_layernorm(hidden_states)
          hidden_states = self.mlp(hidden_states)
          hidden_states = residual + hidden_states
-
-         return hidden_states, gate_mean
-
-     def _compute_attention(
-         self,
-         query: torch.Tensor,
-         key: torch.Tensor,
-         value: torch.Tensor,
-         attention_mask: Optional[torch.Tensor],
-     ) -> torch.Tensor:
-         """Standard attention computation."""
-         head_dim = query.shape[-1]
-         attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
-
-         if attention_mask is not None:
-             causal_mask = attention_mask[:, :, :, : key.shape[-2]]
-             attn_weights = attn_weights + causal_mask
-
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-         attn_output = torch.matmul(attn_weights, value)
-         return attn_output
-
-     def _compute_attention_with_window(
-         self,
-         query: torch.Tensor,
-         key: torch.Tensor,
-         value: torch.Tensor,
-         attention_mask: Optional[torch.Tensor],
-         window_size: int,
-     ) -> torch.Tensor:
-         """Attention with sliding window."""
-         q_len = query.shape[2]
-         k_len = key.shape[2]
-         head_dim = query.shape[-1]
-
-         # If sequence fits in window, use standard attention
-         if q_len <= window_size:
-             return self._compute_attention(query, key, value, attention_mask)
-
-         attn_weights = torch.matmul(query, key.transpose(2, 3)) / math.sqrt(head_dim)
-
-         # Apply causal mask
-         if attention_mask is not None:
-             causal_mask = attention_mask[:, :, :, : key.shape[-2]]
-             attn_weights = attn_weights + causal_mask
-
-         # Apply sliding window mask
-         row_idx = torch.arange(q_len, device=query.device).unsqueeze(1)
-         col_idx = torch.arange(k_len, device=query.device).unsqueeze(0)
-         # Can only attend to positions in [i - window_size + 1, i]
-         window_mask = (col_idx > row_idx) | (col_idx < row_idx - window_size + 1)
-         window_mask = window_mask.unsqueeze(0).unsqueeze(0)
-         attn_weights = attn_weights.masked_fill(window_mask, float('-inf'))
-
-         attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query.dtype)
-         attn_output = torch.matmul(attn_weights, value)
-         return attn_output

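The decoder layer keeps the usual pre-norm residual wiring: each sub-block sees a normalized input and its output is added back to the unnormalized stream. A miniature of that pattern, with nn.LayerNorm and nn.Linear as stand-ins for the model's RMSNorm and attention/MLP blocks:

import torch
from torch import nn

norm = nn.LayerNorm(16)    # stand-in for IQuestLoopCoderRMSNorm
block = nn.Linear(16, 16)  # stand-in for the attention or MLP sub-block
x = torch.randn(2, 3, 16)
x = x + block(norm(x))     # residual wraps the normalized sub-block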
  class IQuestLoopCoderPreTrainedModel(PreTrainedModel):
-     """Base class for IQuestLoopCoder models."""
-     config_class = IQuestLoopCoderConfig
      base_model_prefix = "model"
      supports_gradient_checkpointing = True
      _no_split_modules = ["IQuestLoopCoderDecoderLayer"]
      _skip_keys_device_placement = ["past_key_values"]
-     _supports_cache_class = True
-     _supports_static_cache = True

-     def _init_weights(self, module):
-         std = self.config.initializer_range
-         if isinstance(module, nn.Linear):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.bias is not None:
-                 module.bias.data.zero_()
-         elif isinstance(module, nn.Embedding):
-             module.weight.data.normal_(mean=0.0, std=std)
-             if module.padding_idx is not None:
-                 module.weight.data[module.padding_idx].zero_()

  class IQuestLoopCoderModel(IQuestLoopCoderPreTrainedModel):
-     """IQuestLoopCoder Transformer decoder model."""
-
      def __init__(self, config: IQuestLoopCoderConfig):
          super().__init__(config)
          self.padding_idx = config.pad_token_id
          self.vocab_size = config.vocab_size
-
-         self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
-         self.layers = nn.ModuleList([
-             IQuestLoopCoderDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)
-         ])
          self.norm = IQuestLoopCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
-
          # Gate projections for Loop 2+ (one per layer)
          self.gate_projections = nn.ModuleList([
              LoopGateProjection(config.num_attention_heads, config.head_dim)
              for _ in range(config.num_hidden_layers)
          ])
-
-         # Loop configuration
-         self.loop_num = config.loop_num
-         self.loop_window_size = config.loop_window_size
-
-         self.gradient_checkpointing = False
-         self.post_init()
-
-     def get_input_embeddings(self):
-         return self.embed_tokens

-     def set_input_embeddings(self, value):
-         self.embed_tokens = value

      def forward(
          self,
-         input_ids: torch.LongTensor = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
-     ) -> Union[Tuple, BaseModelOutputWithPast]:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         use_cache = use_cache if use_cache is not None else self.config.use_cache
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

          if inputs_embeds is None:
              inputs_embeds = self.embed_tokens(input_ids)

-         seq_length = inputs_embeds.shape[1]
-
-         # Determine which forward path to use:
-         # 1. If past_key_values exists and seq_length == 1: autoregressive generation step
-         #    -> Use standard attention with KV cache (no loop needed for single token)
-         # 2. Otherwise (prefill or training): use loop mechanism
-
-         is_generation_step = past_key_values is not None and seq_length == 1
-
-         if is_generation_step:
-             # Autoregressive generation: single token, use KV cache
-             return self._forward_with_cache(
-                 inputs_embeds=inputs_embeds,
-                 attention_mask=attention_mask,
-                 position_ids=position_ids,
-                 past_key_values=past_key_values,
-                 use_cache=use_cache,
-                 output_attentions=output_attentions,
-                 output_hidden_states=output_hidden_states,
-                 return_dict=return_dict,
-                 cache_position=cache_position,
-             )
-
-         # Prefill or training: use loop mechanism
-         return self._forward_loop(
-             inputs_embeds=inputs_embeds,
-             attention_mask=attention_mask,
-             position_ids=position_ids,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
-             use_cache=use_cache,
-             cache_position=cache_position,
-         )
-
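The dispatch rule encoded above, reduced to a pure function (the names here are ours, for illustration): a single cached token goes down the decode path, everything else (prefill or training) runs the loop mechanism.

def pick_path(past_key_values, seq_length):
    # Mirrors the is_generation_step check in the deleted forward.
    if past_key_values is not None and seq_length == 1:
        return "decode_with_cache"
    return "loop_prefill"

assert pick_path(object(), 1) == "decode_with_cache"
assert pick_path(None, 128) == "loop_prefill"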
-     def _forward_loop(
-         self,
-         inputs_embeds: torch.Tensor,
-         attention_mask: Optional[torch.Tensor],
-         position_ids: Optional[torch.LongTensor],
-         output_attentions: bool,
-         output_hidden_states: bool,
-         return_dict: bool,
-         use_cache: bool = False,
-         cache_position: Optional[torch.LongTensor] = None,
-     ) -> Union[Tuple, BaseModelOutputWithPast]:
-         """Forward with loop mechanism (for training and prefill).
-
-         This implements the Loop mechanism:
-         - Loop 1: Standard attention, stores K1, V1 for each layer
-         - Loop 2+: Mixed attention with gated combination of global (K1,V1) and local (K2,V2)
-         """
-         batch_size, seq_length, _ = inputs_embeds.shape
-
-         if position_ids is None:
-             device = inputs_embeds.device
-             position_ids = torch.arange(seq_length, dtype=torch.long, device=device).unsqueeze(0)
-
-         if cache_position is None:
-             cache_position = torch.arange(seq_length, device=inputs_embeds.device)
-
-         # Create causal mask
-         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, None, output_attentions)
-
-         hidden_states = inputs_embeds
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-
-         # For KV cache during prefill - use IQuestLoopCoderCache
-         # In prefill, past_key_values should be None, so we create a new cache
          if use_cache:
-             next_decoder_cache = IQuestLoopCoderCache(self.loop_window_size, len(self.layers))
-         else:
-             next_decoder_cache = None
-
-         # ============ Loop 1: Standard forward, store K1, V1 in shared cache ============
-         for layer_idx, decoder_layer in enumerate(self.layers):
-             if output_hidden_states:
-                 all_hidden_states += (hidden_states,)
-
-             # Get K1, V1 before standard forward (from original hidden_states, after layernorm)
-             hidden_states_normed = decoder_layer.input_layernorm(hidden_states)
-             q1, k1, v1 = decoder_layer.self_attn.get_qkv(hidden_states_normed, position_ids)
-
-             # Store K1, V1 in shared cache
-             if use_cache:
-                 next_decoder_cache.update_shared(k1, v1, layer_idx)
-
-             # Standard forward
-             layer_outputs = decoder_layer(
-                 hidden_states,
-                 attention_mask=causal_mask,
-                 position_ids=position_ids,
-                 past_key_value=None,
-                 output_attentions=output_attentions,
-                 use_cache=False,
-             )
-             hidden_states = layer_outputs[0]
-
-             if output_attentions:
-                 all_self_attns += (layer_outputs[1],)
-
-         # ============ Loop 2 to loop_num: Mixed attention, store in local cache ============
-         for loop_idx in range(2, self.loop_num + 1):
-             for layer_idx, decoder_layer in enumerate(self.layers):
-                 # Get K1, V1 from shared cache
-                 k1, v1 = next_decoder_cache.get_shared(layer_idx) if use_cache else (None, None)
-                 if k1 is None or v1 is None:
-                     # Fallback: compute K1, V1 if not in cache (shouldn't happen in prefill)
-                     hidden_states_normed = decoder_layer.input_layernorm(hidden_states)
-                     _, k1, v1 = decoder_layer.self_attn.get_qkv(hidden_states_normed, position_ids)
-
-                 gate_proj = self.gate_projections[layer_idx]
-
-                 hidden_states, gate_mean = decoder_layer.forward_loop2_mixed(
-                     hidden_states,
-                     k1=k1,
-                     v1=v1,
-                     gate_proj=gate_proj,
-                     attention_mask=causal_mask,
-                     position_ids=position_ids,
-                     loop_window_size=self.loop_window_size,
-                 )
-
-                 # Store Loop 2+ KV in local cache (only for loop_idx == 2)
-                 if use_cache and loop_idx == 2:
-                     hidden_states_normed = decoder_layer.input_layernorm(hidden_states)
-                     _, k2, v2 = decoder_layer.self_attn.get_qkv(hidden_states_normed, position_ids)
-                     next_decoder_cache.update_local(k2, v2, layer_idx)
-
-         hidden_states = self.norm(hidden_states)
-
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         if not return_dict:
-             return tuple(v for v in [hidden_states, next_decoder_cache, all_hidden_states, all_self_attns] if v is not None)
-
-         return BaseModelOutputWithPast(
-             last_hidden_state=hidden_states,
-             past_key_values=next_decoder_cache,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
-         )
-
-     def _forward_with_cache(
-         self,
-         inputs_embeds: torch.Tensor,
-         attention_mask: Optional[torch.Tensor],
-         position_ids: Optional[torch.LongTensor],
-         past_key_values: Optional[Cache],
-         use_cache: bool,
-         output_attentions: bool,
-         output_hidden_states: bool,
-         return_dict: bool,
-         cache_position: Optional[torch.LongTensor],
-     ) -> Union[Tuple, BaseModelOutputWithPast]:
-         """Forward with KV cache using loop mechanism (for inference generation).
-
-         Loop 1: Standard attention, uses shared KV cache (previous tokens + current token)
-         Loop 2+: Mixed attention, uses local KV cache (sliding window)
-         """
-         batch_size, seq_length, _ = inputs_embeds.shape
-
          if cache_position is None:
-             past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
-             cache_position = torch.arange(past_seen_tokens, past_seen_tokens + seq_length, device=inputs_embeds.device)
-
          if position_ids is None:
              position_ids = cache_position.unsqueeze(0)
-
-         causal_mask = self._update_causal_mask(attention_mask, inputs_embeds, cache_position, past_key_values, output_attentions)
-
-         # Ensure we're using IQuestLoopCoderCache
-         if use_cache:
-             if not isinstance(past_key_values, IQuestLoopCoderCache):
-                 # Convert to IQuestLoopCoderCache if needed
-                 next_decoder_cache = IQuestLoopCoderCache(self.loop_window_size, len(self.layers))
-                 # Copy existing cache if possible
-                 if past_key_values is not None:
-                     for layer_idx in range(len(self.layers)):
-                         try:
-                             past_k = past_key_values.key_cache[layer_idx] if hasattr(past_key_values, 'key_cache') else None
-                             past_v = past_key_values.value_cache[layer_idx] if hasattr(past_key_values, 'value_cache') else None
-                             if past_k is not None and past_v is not None:
-                                 next_decoder_cache.update_shared(past_k, past_v, layer_idx)
-                         except:
-                             pass
-             else:
-                 next_decoder_cache = past_key_values
-         else:
-             next_decoder_cache = None
-
          hidden_states = inputs_embeds
-         all_hidden_states = () if output_hidden_states else None
-         all_self_attns = () if output_attentions else None
-
-         # ============ Loop 1: Standard attention, store in shared cache ============
-         for layer_idx, decoder_layer in enumerate(self.layers):
-             if output_hidden_states:
-                 all_hidden_states += (hidden_states,)
-
-             # Get past shared KV cache
-             past_shared_key, past_shared_value = None, None
-             if next_decoder_cache is not None:
-                 past_shared_key, past_shared_value = next_decoder_cache.get_shared(layer_idx)
-
-             # Forward Loop 1
-             attn_output, k1, v1 = decoder_layer.self_attn.forward_decode_loop1(
-                 hidden_states=decoder_layer.input_layernorm(hidden_states),
-                 past_shared_key=past_shared_key,
-                 past_shared_value=past_shared_value,
-                 attention_mask=causal_mask,
-                 position_ids=position_ids,
-                 cache_position=cache_position,
-             )
-
-             # Update shared cache with current token's Loop 1 KV
-             if use_cache:
-                 next_decoder_cache.update_shared(k1, v1, layer_idx)
-
-             hidden_states = hidden_states + attn_output
-
-             # MLP
-             residual = hidden_states
-             hidden_states = decoder_layer.post_attention_layernorm(hidden_states)
-             hidden_states = decoder_layer.mlp(hidden_states)
-             hidden_states = residual + hidden_states

-             if output_attentions:
-                 all_self_attns += (None,)  # We don't return attention weights in decode loop
-
-         # ============ Loop 2 to loop_num: Mixed attention, store in local cache ============
-         # Store k1, v1 from Loop 1 for use in Loop 2+
-         loop1_kv = []
-         for layer_idx in range(len(self.layers)):
-             if next_decoder_cache is not None:
-                 k1_full, v1_full = next_decoder_cache.get_shared(layer_idx)
-                 if k1_full is not None and v1_full is not None:
-                     # Get only the last token (current token)
-                     loop1_kv.append((k1_full[:, :, -1:, :], v1_full[:, :, -1:, :], k1_full, v1_full))
-                 else:
-                     loop1_kv.append((None, None, None, None))
-             else:
-                 loop1_kv.append((None, None, None, None))
-
-         for loop_idx in range(2, self.loop_num + 1):
-             for layer_idx, decoder_layer in enumerate(self.layers):
-                 # Get k1, v1 (current token's Loop 1 KV) and full shared cache
-                 k1_current, v1_current, k1_full, v1_full = loop1_kv[layer_idx]
-                 if k1_current is None or v1_current is None:
-                     continue
-
-                 # Get past local KV cache
-                 past_local_key, past_local_value = None, None
-                 if next_decoder_cache is not None:
-                     past_local_key, past_local_value = next_decoder_cache.get_local(layer_idx)
-
-                 gate_proj = self.gate_projections[layer_idx]
-
-                 # Forward Loop 2+
-                 attn_output, k2, v2 = decoder_layer.self_attn.forward_decode_loop2(
-                     hidden_states=decoder_layer.input_layernorm(hidden_states),
-                     k1=k1_current,
-                     v1=v1_current,
-                     past_shared_key=k1_full[:, :, :-1, :] if k1_full is not None and k1_full.shape[2] > 1 else None,
-                     past_shared_value=v1_full[:, :, :-1, :] if v1_full is not None and v1_full.shape[2] > 1 else None,
-                     past_local_key=past_local_key,
-                     past_local_value=past_local_value,
-                     gate_proj=gate_proj,
-                     attention_mask=causal_mask,
                      position_ids=position_ids,
-                     loop_window_size=self.loop_window_size,
                  )
-
-                 # Update local cache with current token's Loop 2+ KV
-                 if use_cache and loop_idx == 2:
-                     next_decoder_cache.update_local(k2, v2, layer_idx)
-
-                 hidden_states = hidden_states + attn_output
-
-                 # MLP
-                 residual = hidden_states
-                 hidden_states = decoder_layer.post_attention_layernorm(hidden_states)
-                 hidden_states = decoder_layer.mlp(hidden_states)
-                 hidden_states = residual + hidden_states
-
          hidden_states = self.norm(hidden_states)
-
-         if output_hidden_states:
-             all_hidden_states += (hidden_states,)
-
-         next_cache = next_decoder_cache if use_cache else None
-
-         if not return_dict:
-             return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None)
-
-         return BaseModelOutputWithPast(
-             last_hidden_state=hidden_states,
-             past_key_values=next_cache,
-             hidden_states=all_hidden_states,
-             attentions=all_self_attns,
          )
-
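forward_decode_loop2 receives the shared cache split exactly as above: the newest position is this step's k1/v1, everything before it is the "past" for attention A. In isolation (illustrative shapes):

import torch

k1_full = torch.randn(1, 4, 7, 8)
k1_current = k1_full[:, :, -1:, :]    # current token only
past_shared = k1_full[:, :, :-1, :]   # previously cached tokens
assert k1_current.shape[2] == 1 and past_shared.shape[2] == 6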
-     def _update_causal_mask(
-         self,
-         attention_mask: torch.Tensor,
-         input_tensor: torch.Tensor,
-         cache_position: torch.Tensor,
-         past_key_values: Cache,
-         output_attentions: bool,
-     ):
-         """Create causal attention mask."""
-         dtype, device = input_tensor.dtype, input_tensor.device
-         min_dtype = torch.finfo(dtype).min
-         sequence_length = input_tensor.shape[1]
-
-         # Determine target length for attention
-         if past_key_values is not None:
-             # For DynamicCache: use get_seq_length() to get cached length
-             # target_length = cached_length + current_sequence_length
-             past_length = past_key_values.get_seq_length()
-             target_length = past_length + sequence_length
-         elif attention_mask is not None:
-             target_length = attention_mask.shape[-1]
-         else:
-             target_length = sequence_length
-
-         # Create causal mask
-         causal_mask = torch.full((sequence_length, target_length), fill_value=min_dtype, dtype=dtype, device=device)
-         if sequence_length != 1:
-             # For prefill: standard causal mask
-             causal_mask = torch.triu(causal_mask, diagonal=1)
-
-         # Adjust for cache position (for generation steps after prefill)
-         causal_mask *= torch.arange(target_length, device=device) > cache_position.reshape(-1, 1)
-         causal_mask = causal_mask[None, None, :, :].expand(input_tensor.shape[0], 1, -1, -1)
-
-         if attention_mask is not None:
-             causal_mask = causal_mask.clone()
-             mask_length = attention_mask.shape[-1]
-             if mask_length <= target_length:
-                 padding_mask = causal_mask[:, :, :, :mask_length] + attention_mask[:, None, None, :]
-                 padding_mask = padding_mask == 0
-                 causal_mask[:, :, :, :mask_length] = causal_mask[:, :, :, :mask_length].masked_fill(padding_mask, min_dtype)
-
-         return causal_mask

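The additive mask built above, in miniature: 0 where attention is allowed, the dtype minimum (effectively -inf after softmax) above the diagonal.

import torch

dtype = torch.float32
seq, tgt = 4, 4
mask = torch.full((seq, tgt), torch.finfo(dtype).min, dtype=dtype)
mask = torch.triu(mask, diagonal=1)
assert mask[2, 1] == 0 and mask[1, 2] == torch.finfo(dtype).min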
  class IQuestLoopCoderForCausalLM(IQuestLoopCoderPreTrainedModel, GenerationMixin):
-     """IQuestLoopCoder model with a causal language modeling head."""
      _tied_weights_keys = ["lm_head.weight"]

      def __init__(self, config):
          super().__init__(config)
          self.model = IQuestLoopCoderModel(config)
          self.vocab_size = config.vocab_size
          self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
          self.post_init()

      def get_input_embeddings(self):
@@ -1321,42 +1019,80 @@ class IQuestLoopCoderForCausalLM(IQuestLoopCoderPreTrainedModel, GenerationMixin
      def get_decoder(self):
          return self.model

      def forward(
          self,
-         input_ids: torch.LongTensor = None,
          attention_mask: Optional[torch.Tensor] = None,
          position_ids: Optional[torch.LongTensor] = None,
          past_key_values: Optional[Cache] = None,
          inputs_embeds: Optional[torch.FloatTensor] = None,
          labels: Optional[torch.LongTensor] = None,
          use_cache: Optional[bool] = None,
-         output_attentions: Optional[bool] = None,
-         output_hidden_states: Optional[bool] = None,
-         return_dict: Optional[bool] = None,
          cache_position: Optional[torch.LongTensor] = None,
-     ) -> Union[Tuple, CausalLMOutputWithPast]:
-         output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions
-         output_hidden_states = output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
-         return_dict = return_dict if return_dict is not None else self.config.use_return_dict

-         outputs = self.model(
              input_ids=input_ids,
              attention_mask=attention_mask,
              position_ids=position_ids,
              past_key_values=past_key_values,
              inputs_embeds=inputs_embeds,
              use_cache=use_cache,
-             output_attentions=output_attentions,
-             output_hidden_states=output_hidden_states,
-             return_dict=return_dict,
              cache_position=cache_position,
          )

-         hidden_states = outputs[0]
          logits = self.lm_head(hidden_states)
          logits = logits.float()

-         loss = None
          if labels is not None:
              shift_logits = logits[..., :-1, :].contiguous()
              shift_labels = labels[..., 1:].contiguous()
@@ -1366,11 +1102,7 @@ class IQuestLoopCoderForCausalLM(IQuestLoopCoderPreTrainedModel, GenerationMixin
              shift_labels = shift_labels.to(shift_logits.device)
              loss = loss_fct(shift_logits, shift_labels)

-         if not return_dict:
-             output = (logits,) + outputs[1:]
-             return (loss,) + output if loss is not None else output
-
-         return CausalLMOutputWithPast(
              loss=loss,
              logits=logits,
              past_key_values=outputs.past_key_values,
@@ -1378,44 +1110,4 @@ class IQuestLoopCoderForCausalLM(IQuestLoopCoderPreTrainedModel, GenerationMixin
              attentions=outputs.attentions,
          )

-     def prepare_inputs_for_generation(
-         self,
-         input_ids,
-         past_key_values=None,
-         attention_mask=None,
-         inputs_embeds=None,
-         cache_position=None,
-         use_cache=True,
-         **kwargs,
-     ):
-         past_length = 0
-         if past_key_values is not None:
-             past_length = past_key_values.get_seq_length()
-             if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
-                 input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
-             elif past_length < input_ids.shape[1]:
-                 input_ids = input_ids[:, past_length:]
-
-         if cache_position is None:
-             cache_position = torch.arange(past_length, past_length + input_ids.shape[1], device=input_ids.device)
-         elif use_cache:
-             cache_position = cache_position[-input_ids.shape[1]:]
-
-         position_ids = cache_position.unsqueeze(0)
-
-         if inputs_embeds is not None and past_key_values is None:
-             model_inputs = {"inputs_embeds": inputs_embeds}
-         else:
-             model_inputs = {"input_ids": input_ids.contiguous()}
-
-         model_inputs.update(
-             {
-                 "position_ids": position_ids,
-                 "cache_position": cache_position,
-                 "past_key_values": past_key_values,
-                 "use_cache": use_cache,
-                 "attention_mask": attention_mask,
-             }
-         )
-         return model_inputs
-
+import logging
+from typing import Any, Callable, Optional, Union, Tuple, List

 import torch
 from torch import nn

 from transformers.activations import ACT2FN
+from transformers.cache_utils import Cache
+from transformers.generation import GenerationMixin
+from transformers.integrations import use_kernel_forward_from_hub
+from transformers.masking_utils import (
+    create_causal_mask,
+    create_sliding_window_causal_mask,
+)
+from transformers.modeling_flash_attention_utils import FlashAttentionKwargs
+from transformers.modeling_layers import (
+    GenericForQuestionAnswering,
+    GenericForSequenceClassification,
+    GenericForTokenClassification,
+    GradientCheckpointingLayer,
+)
 from transformers.modeling_outputs import (
     BaseModelOutputWithPast,
     CausalLMOutputWithPast,
 )
+from transformers.modeling_rope_utils import ROPE_INIT_FUNCTIONS, dynamic_rope_update
+from transformers.modeling_utils import ALL_ATTENTION_FUNCTIONS, PreTrainedModel
+from transformers.processing_utils import Unpack
+from transformers.utils import TransformersKwargs, auto_docstring, can_return_tuple
+from transformers.utils.generic import check_model_inputs

 from .configuration_iquestloopcoder import IQuestLoopCoderConfig


+logger = logging.getLogger(__name__)
+
+
+def needs_iquestloopcoder_cache(
+    cache: Optional[Cache]
+) -> bool:
+    # need to test more conditions
+    if cache is None:
+        return True
+    if isinstance(cache, IQuestLoopCoderCache):
+        return False
+    return True
+
+class IQuestLoopCoderMLP(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.hidden_size = config.hidden_size
+        self.intermediate_size = config.intermediate_size
+        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
+        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
+        self.act_fn = ACT2FN[config.hidden_act]
+
+    def forward(self, x):
+        down_proj = self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))
+        return down_proj
+
+
+def rotate_half(x):
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2 :]
+    return torch.cat((-x2, x1), dim=-1)
+
+
+def apply_rotary_pos_emb(q, k, cos, sin, position_ids=None, unsqueeze_dim=1):
+    """Applies Rotary Position Embedding to the query and key tensors.
+
+    Args:
+        q (`torch.Tensor`): The query tensor.
+        k (`torch.Tensor`): The key tensor.
+        cos (`torch.Tensor`): The cosine part of the rotary embedding.
+        sin (`torch.Tensor`): The sine part of the rotary embedding.
+        position_ids (`torch.Tensor`, *optional*):
+            Deprecated and unused.
+        unsqueeze_dim (`int`, *optional*, defaults to 1):
+            The dimension along which to unsqueeze cos and sin so that they broadcast correctly
+            onto q and k. For example, cos and sin have the shape [batch_size, seq_len, head_dim].
+            If q and k have the shape [batch_size, heads, seq_len, head_dim], unsqueeze_dim=1 makes
+            cos and sin broadcastable over the head dimension. If q and k instead have the shape
+            [batch_size, seq_len, heads, head_dim], set unsqueeze_dim=2.
+    Returns:
+        `tuple(torch.Tensor)` comprising of the query and key tensors rotated using the Rotary Position Embedding.
+    """
+    cos = cos.unsqueeze(unsqueeze_dim)
+    sin = sin.unsqueeze(unsqueeze_dim)
+    q_embed = (q * cos) + (rotate_half(q) * sin)
+    k_embed = (k * cos) + (rotate_half(k) * sin)
+    return q_embed, k_embed
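
For intuition, a standalone shape sketch of the rotary application above (illustrative values; cos/sin here are random stand-ins, not real RoPE frequencies, and the function definitions above are assumed):

    import torch

    # q, k: [batch=1, heads=2, seq_len=4, head_dim=8]; cos/sin: [batch, seq_len, head_dim]
    q = torch.randn(1, 2, 4, 8)
    k = torch.randn(1, 2, 4, 8)
    cos = torch.randn(1, 4, 8)
    sin = torch.randn(1, 4, 8)

    # With the default unsqueeze_dim=1, cos/sin become [1, 1, 4, 8] and broadcast over heads.
    q_embed, k_embed = apply_rotary_pos_emb(q, k, cos, sin)
    assert q_embed.shape == q.shape and k_embed.shape == k.shape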
+
+
+def repeat_kv(hidden_states: torch.Tensor, n_rep: int) -> torch.Tensor:
+    """
+    This is the equivalent of torch.repeat_interleave(x, dim=1, repeats=n_rep). The hidden states go from (batch,
+    num_key_value_heads, seqlen, head_dim) to (batch, num_attention_heads, seqlen, head_dim)
+    """
+    batch, num_key_value_heads, slen, head_dim = hidden_states.shape
+    if n_rep == 1:
+        return hidden_states
+    hidden_states = hidden_states[:, :, None, :, :].expand(
+        batch, num_key_value_heads, n_rep, slen, head_dim
+    )
+    return hidden_states.reshape(batch, num_key_value_heads * n_rep, slen, head_dim)
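
A quick check of the GQA head expansion performed by repeat_kv (toy shapes, standalone sketch):

    import torch

    kv = torch.randn(1, 2, 5, 16)   # [batch, num_key_value_heads=2, seq, head_dim]
    expanded = repeat_kv(kv, 4)     # 2 KV heads * n_rep=4 -> 8 attention heads
    assert expanded.shape == (1, 8, 5, 16)
    assert torch.equal(expanded[0, 0], expanded[0, 1])  # heads 0..3 are copies of KV head 0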
 class IQuestLoopCoderCache(Cache):
     - local_key_cache/local_value_cache: Stores KV from Loop 2+ (local window, only window_size tokens)
     """

+    def __init__(self, window_size: int, num_layers: int, loop_num: int = 2):
         # We intentionally don't call super().__init__ because the parent assumes static cache sizes.
         self.window_size = window_size
         self.num_layers = num_layers
+        self.loop_num = loop_num

+        # Shared cache: stores Loop 1 KV (global context)
+        self.shared_key_cache: List[Optional[torch.Tensor]] = [None] * self.num_layers
+        self.shared_value_cache: List[Optional[torch.Tensor]] = [None] * self.num_layers

         # Local cache: stores Loop 2+ KV (sliding window, only window_size tokens)
+        self.local_key_cache: List[Optional[torch.Tensor]] = [None] * (self.loop_num - 1) * self.num_layers
+        self.local_value_cache: List[Optional[torch.Tensor]] = [None] * (self.loop_num - 1) * self.num_layers

         self.layers: List[Any] = []  # attribute expected by HF Cache utilities
         self._seen_tokens = 0
         cache_kwargs: Optional[dict] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Update shared cache (Loop 1 KV)."""
+        # only store the first loop's kv cache
+        loop_idx = cache_kwargs.get("loop_idx", 0)
+        assert loop_idx == 0
         if layer_idx < 0 or layer_idx >= self.num_layers:
             raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")

             raise ValueError(
                 "Cached and incoming key/value tensors must match on batch, head, and head_dim dimensions."
             )
+            assert key_states.shape[2] == 1
+            assert value_states.shape[2] == 1
             self.shared_key_cache[layer_idx] = torch.cat([cached_key, key_states], dim=2)
             self.shared_value_cache[layer_idx] = torch.cat([cached_value, value_states], dim=2)
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Update local cache (Loop 2+ KV) with sliding window management.

+        Ensures the local cache always contains at most window_size tokens.
+        The local cache only stores loops with loop_idx > 0 (i.e., loop_idx = 1, 2, ...).
+        For loop_idx = 1, cache_idx = layer_idx + 0 * num_layers = layer_idx (0 to num_layers-1).
+        For loop_idx = 2, cache_idx = layer_idx + 1 * num_layers (num_layers to 2*num_layers-1).
         """
+        # only store the local kv cache for loop_idx > 0
+        loop_idx = cache_kwargs.get("loop_idx", 0)
+        assert loop_idx > 0, f"update_local should only be called for loop_idx > 0, got {loop_idx}"
         if layer_idx < 0 or layer_idx >= self.num_layers:
             raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")

+        # Local cache size is (loop_num - 1) * num_layers:
+        # loop_idx = 1 maps to indices 0 to num_layers-1,
+        # loop_idx = 2 maps to indices num_layers to 2*num_layers-1,
+        # so offset = (loop_idx - 1) * num_layers.
+        cache_idx = layer_idx + (loop_idx - 1) * self.num_layers
+
+        # Validate cache_idx is within bounds
+        max_cache_idx = (self.loop_num - 1) * self.num_layers
+        if cache_idx >= max_cache_idx:
+            raise IndexError(
+                f"cache_idx {cache_idx} out of range. "
+                f"loop_idx={loop_idx}, layer_idx={layer_idx}, "
+                f"max_cache_idx={max_cache_idx - 1}"
+            )
+        cached_key = self.local_key_cache[cache_idx]
+        cached_value = self.local_value_cache[cache_idx]
         if cached_key is None:
+            # First tokens in local cache, for prefill.
+            # If the prefill sequence is longer than window_size, only keep the last window_size tokens.
+            seq_len = key_states.shape[2]
+            if seq_len > self.window_size:
+                # Keep only the last window_size tokens
+                start_idx = seq_len - self.window_size
+                self.local_key_cache[cache_idx] = key_states[:, :, start_idx:, :]
+                self.local_value_cache[cache_idx] = value_states[:, :, start_idx:, :]
+            else:
+                self.local_key_cache[cache_idx] = key_states
+                self.local_value_cache[cache_idx] = value_states
         else:
+            # store the local kv cache for decode
             if (
                 key_states.shape[0] != cached_key.shape[0]
                 or key_states.shape[1] != cached_key.shape[1]

                 "Cached and incoming key/value tensors must match on batch, head, and head_dim dimensions."
             )
             assert cached_value is not None
+            assert key_states.shape[2] == 1
+            assert value_states.shape[2] == 1
+            # Concatenate new tokens
+            new_key = torch.cat([cached_key, key_states], dim=2)
+            new_value = torch.cat([cached_value, value_states], dim=2)

+            # Ensure the total length doesn't exceed window_size
+            total_len = new_key.shape[2]
+            if total_len > self.window_size:
+                # Keep only the last window_size tokens
+                self.local_key_cache[cache_idx] = new_key[:, :, -self.window_size:, :]
+                self.local_value_cache[cache_idx] = new_value[:, :, -self.window_size:, :]
             else:
+                self.local_key_cache[cache_idx] = new_key
+                self.local_value_cache[cache_idx] = new_value

+        result_key = self.local_key_cache[cache_idx]
+        result_value = self.local_value_cache[cache_idx]
         assert result_key is not None and result_value is not None
+        # The result is at most window_size (it can be less during prefill when the sequence is shorter)
+        assert result_key.shape[2] <= self.window_size, f"Local cache size {result_key.shape[2]} exceeds window_size {self.window_size}"

         return result_key, result_value
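
A standalone sanity check of the flat cache_idx layout used above (hypothetical sizes):

    num_layers, loop_num = 3, 3
    slots = [
        (loop_idx, layer_idx, layer_idx + (loop_idx - 1) * num_layers)
        for loop_idx in (1, 2)
        for layer_idx in range(num_layers)
    ]
    # loop_idx=1 fills slots 0..2, loop_idx=2 fills slots 3..5,
    # i.e. (loop_num - 1) * num_layers = 6 slots in total.
    assert [s[2] for s in slots] == [0, 1, 2, 3, 4, 5]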
+    def get_shared(self, layer_idx: int | List[int]) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
+        """Get the shared cache for a layer (or a list of layers)."""
+        if isinstance(layer_idx, list):
+            return [self.get_shared(idx) for idx in layer_idx]
         if layer_idx < 0 or layer_idx >= self.num_layers:
+            raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")
         return self.shared_key_cache[layer_idx], self.shared_value_cache[layer_idx]

+    def get_local(self, layer_idx: int | List[int], loop_idx: int) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
         """Get local cache for a layer."""
+        assert loop_idx > 0, f"get_local should only be called for loop_idx > 0, got {loop_idx}"
+        if isinstance(layer_idx, list):
+            return [self.get_local(idx, loop_idx) for idx in layer_idx]
         if layer_idx < 0 or layer_idx >= self.num_layers:
+            raise ValueError(f"layer_idx must be in [0, {self.num_layers}), got {layer_idx}")
+
+        # Same flat layout as update_local: offset = (loop_idx - 1) * num_layers.
+        cache_idx = layer_idx + (loop_idx - 1) * self.num_layers
+
+        # Validate cache_idx is within bounds
+        max_cache_idx = (self.loop_num - 1) * self.num_layers
+        if cache_idx >= max_cache_idx:
+            raise IndexError(
+                f"cache_idx {cache_idx} out of range. "
+                f"loop_idx={loop_idx}, layer_idx={layer_idx}, "
+                f"max_cache_idx={max_cache_idx - 1}"
+            )
+
+        return self.local_key_cache[cache_idx], self.local_value_cache[cache_idx]
     def update(
         self,
         cache_kwargs: Optional[dict] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         """Default update method: dispatches to the shared cache (loop_idx == 0) or the local cache (loop_idx > 0)."""
+        loop_idx = cache_kwargs.get("loop_idx", 0)
+        assert loop_idx < self.loop_num
+        if loop_idx == 0:
+            return self.update_shared(key_states, value_states, layer_idx, cache_kwargs)
+        else:
+            return self.update_local(key_states, value_states, layer_idx, cache_kwargs)

     def get_seq_length(self, layer_idx: Optional[int] = 0) -> int:
         """Get sequence length from shared cache."""
         if layer_idx is None:
             layer_idx = 0
+        if layer_idx < 0 or layer_idx >= self.loop_num * self.num_layers:
             return 0
+        cached_key = self.shared_key_cache[layer_idx]
+        if cached_key is None:
             return 0
+        return cached_key.shape[2]

     def get_max_length(self) -> Optional[int]:
         return None
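
A minimal usage sketch of the dispatch above, assuming update takes (key_states, value_states, layer_idx, cache_kwargs) as in the calls made by the attention layers (toy shapes):

    import torch

    cache = IQuestLoopCoderCache(window_size=4, num_layers=1, loop_num=2)
    k = torch.randn(1, 2, 1, 8)  # [batch, kv_heads, seq=1, head_dim]
    v = torch.randn(1, 2, 1, 8)

    cache.update(k, v, 0, {"loop_idx": 0})  # routed to the shared (Loop 1) cache
    cache.update(k, v, 0, {"loop_idx": 1})  # routed to the sliding-window local cache
    assert cache.get_seq_length() == 1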
 
         return self.get_seq_length(layer_idx)

     def reorder_cache(self, beam_idx: torch.LongTensor) -> None:
+        """Reorder cache for beam search.
+
+        Would reorder both the shared cache (Loop 1) and the local cache (Loop 2+)
+        according to beam_idx; beam search is not supported yet, so the reference
+        code below is intentionally unreachable.
+        """
+        raise NotImplementedError("Reorder cache for beam search is not implemented")
+        # Reorder shared cache (Loop 1, loop_idx=0)
         for layer_idx in range(self.num_layers):
             if self.shared_key_cache[layer_idx] is not None:
                 device = self.shared_key_cache[layer_idx].device
                 self.shared_key_cache[layer_idx] = self.shared_key_cache[layer_idx].index_select(0, beam_idx.to(device))
                 self.shared_value_cache[layer_idx] = self.shared_value_cache[layer_idx].index_select(0, beam_idx.to(device))
+
+        # Reorder local cache (Loop 2+, loop_idx > 0);
+        # local cache size is (loop_num - 1) * num_layers.
+        for cache_idx in range(len(self.local_key_cache)):
+            if self.local_key_cache[cache_idx] is not None:
+                device = self.local_key_cache[cache_idx].device
+                self.local_key_cache[cache_idx] = self.local_key_cache[cache_idx].index_select(0, beam_idx.to(device))
+                self.local_value_cache[cache_idx] = self.local_value_cache[cache_idx].index_select(0, beam_idx.to(device))
     @property
     def is_compileable(self) -> bool:

         logger.debug("Clearing IQuestLoopCoderCache")
         self.shared_key_cache = [None] * self.num_layers
         self.shared_value_cache = [None] * self.num_layers
+        self.local_key_cache = [None] * self.num_layers * (self.loop_num - 1)
+        self.local_value_cache = [None] * self.num_layers * (self.loop_num - 1)
         self._seen_tokens = 0
 
391
+ def eager_attention_forward(
392
+ module: nn.Module,
393
+ query: torch.Tensor,
394
+ key: torch.Tensor,
395
+ value: torch.Tensor,
396
+ attention_mask: Optional[torch.Tensor],
397
+ scaling: float,
398
+ dropout: float = 0.0,
399
+ **kwargs: Unpack[TransformersKwargs],
400
+ ):
401
+ key_states = repeat_kv(key, module.num_key_value_groups)
402
+ value_states = repeat_kv(value, module.num_key_value_groups)
403
+
404
+ attn_weights = torch.matmul(query, key_states.transpose(2, 3)) * scaling
405
+ if attention_mask is not None:
406
+ causal_mask = attention_mask[:, :, :, : key_states.shape[-2]]
407
+ attn_weights = attn_weights + causal_mask
408
+
409
+ attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(
410
+ query.dtype
411
+ )
412
+ attn_weights = nn.functional.dropout(
413
+ attn_weights, p=dropout, training=module.training
414
+ )
415
+ attn_output = torch.matmul(attn_weights, value_states)
416
+ attn_output = attn_output.transpose(1, 2).contiguous()
417
+
418
+ return attn_output, attn_weights
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
419
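
A shape walk-through of eager_attention_forward; DummyAttn is a hypothetical stand-in that only supplies the attributes the function reads (num_key_value_groups and training):

    import torch
    from torch import nn

    class DummyAttn(nn.Module):
        def __init__(self):
            super().__init__()
            self.num_key_value_groups = 2  # 4 query heads / 2 KV heads

    q = torch.randn(1, 4, 5, 8)  # [batch, num_attention_heads, seq, head_dim]
    k = torch.randn(1, 2, 5, 8)  # [batch, num_key_value_heads, seq, head_dim]
    v = torch.randn(1, 2, 5, 8)
    out, weights = eager_attention_forward(DummyAttn(), q, k, v, attention_mask=None, scaling=8 ** -0.5)
    assert out.shape == (1, 5, 4, 8)     # transposed to [batch, seq, heads, head_dim]
    assert weights.shape == (1, 4, 5, 5)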
 
 class LoopGateProjection(nn.Module):
     """Gate projection for mixed attention in Loop 2+.

         gate = torch.sigmoid(gate_logits)
         return gate.unsqueeze(-1)  # [batch, num_heads, seq_len, 1]

 class IQuestLoopCoderAttention(nn.Module):
+    """Multi-headed attention from 'Attention Is All You Need' paper"""
+
+    def __init__(self, config: IQuestLoopCoderConfig, layer_idx: int):
         super().__init__()
         self.config = config
+        assert 0 <= layer_idx < config.num_hidden_layers
         self.layer_idx = layer_idx
+
+        self.head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        self.num_key_value_groups = (
+            config.num_attention_heads // config.num_key_value_heads
+        )
+        self.scaling = self.head_dim**-0.5
         self.attention_dropout = config.attention_dropout
+        self.is_causal = True
+        self.q_proj = nn.Linear(
+            config.hidden_size, config.num_attention_heads * self.head_dim, bias=False
+        )
+        self.k_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.v_proj = nn.Linear(
+            config.hidden_size, config.num_key_value_heads * self.head_dim, bias=False
+        )
+        self.o_proj = nn.Linear(
+            config.num_attention_heads * self.head_dim, config.hidden_size, bias=False
         )

     def forward(
         self,
         hidden_states: torch.Tensor,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
         past_key_value: Optional[Cache] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        loop_idx: int = 0,
+        gate_proj: Optional[LoopGateProjection] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        if loop_idx == 0:
+            return self.forward_loop1(hidden_states, loop_idx, position_embeddings, attention_mask, past_key_value, cache_position, **kwargs)
+        else:
+            return self.forward_loop2(hidden_states, loop_idx, position_embeddings, attention_mask, past_key_value, cache_position, gate_proj, **kwargs)
+
+    def forward_loop1(
+        self,
+        hidden_states: torch.Tensor,
+        loop_idx: int,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[IQuestLoopCoderCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states = apply_rotary_pos_emb(
+            query_states, key_states, cos, sin
+        )

         if past_key_value is not None:
+            # sin and cos are specific to RoPE models; cache_position needed for the static cache
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "loop_idx": loop_idx}
+            key_states, value_states = past_key_value.update(
+                key_states,
+                value_states,
+                self.layer_idx,
+                cache_kwargs,
+            )

+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[
+                self.config._attn_implementation
+            ]
+
+        attn_output, attn_weights = attention_interface(
+            self,
+            query_states,
+            key_states,
+            value_states,
+            attention_mask,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )

+        attn_output = attn_output.reshape(*input_shape, -1).contiguous()
         attn_output = self.o_proj(attn_output)
+        return attn_output, attn_weights
+
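
The backend selection above keys into transformers' ALL_ATTENTION_FUNCTIONS registry; the same pattern in isolation (assuming a transformers version whose registry exposes "sdpa", and the eager_attention_forward defined earlier):

    attn_impl = "sdpa"  # stands in for config._attn_implementation
    attention_interface = (
        eager_attention_forward if attn_impl == "eager" else ALL_ATTENTION_FUNCTIONS[attn_impl]
    )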
+    def forward_loop2(
+        self,
+        hidden_states: torch.Tensor,
+        loop_idx: int,
+        position_embeddings: tuple[torch.Tensor, torch.Tensor],
+        attention_mask: Optional[torch.Tensor],
+        past_key_value: Optional[IQuestLoopCoderCache] = None,
+        cache_position: Optional[torch.LongTensor] = None,
+        gate_proj: Optional[LoopGateProjection] = None,
+        **kwargs: Unpack[FlashAttentionKwargs],
+    ) -> tuple[torch.Tensor, Optional[torch.Tensor], Optional[tuple[torch.Tensor]]]:
+        input_shape = hidden_states.shape[:-1]
+        hidden_shape = (*input_shape, -1, self.head_dim)
+
+        query_states = self.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        key_states_local = self.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+        value_states_local = self.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
+
+        cos, sin = position_embeddings
+        query_states, key_states_local = apply_rotary_pos_emb(
+            query_states, key_states_local, cos, sin
+        )

+        key_states_share, value_states_share = None, None
+        if past_key_value is not None:
+            # get key_share, value_share from past_key_value
+            cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position, "loop_idx": loop_idx}
+            key_states_share, value_states_share = past_key_value.get_shared(self.layer_idx)
+            key_states_local, value_states_local = past_key_value.update(
+                key_states_local,
+                value_states_local,
+                self.layer_idx,
+                cache_kwargs,
+            )
+
+        attention_interface: Callable = eager_attention_forward
+        if self.config._attn_implementation != "eager":
+            attention_interface = ALL_ATTENTION_FUNCTIONS[
+                self.config._attn_implementation
+            ]
+
+        # Create masks for global and local attention.
+        # Global attention: full causal mask (can see all tokens in the shared cache).
+        # Local attention: causal mask over the local window (at most window_size tokens in the local cache).
+        attention_mask_global = attention_mask  # Use full causal mask for global attention
+
+        # For local attention, create a mask that matches the local cache size.
+        # The local cache already contains only the last window_size tokens,
+        # so we need a causal mask that allows attention within this window.
+        attention_mask_local = None
+        if key_states_local is not None and value_states_local is not None:
+            # Local cache has shape [batch, num_heads, local_seq_len, head_dim]
+            # where local_seq_len <= window_size
+            local_seq_len = key_states_local.shape[2]
+            bsz = query_states.shape[0]
+            q_len = query_states.shape[2]

+            # Create a causal mask for local attention: each query position may attend
+            # to all positions up to and including itself within the local window
+            # (which is already the last window_size tokens).
+            device = query_states.device
+            dtype = query_states.dtype

+            if attention_mask is not None:
+                # If we have a global mask, adapt it for local attention.
+                # The global mask shape is [batch, 1, q_len, global_kv_len];
+                # for local attention we only need the last local_seq_len positions.
+                global_kv_len = attention_mask.shape[-1]
+
+                if global_kv_len >= local_seq_len:
+                    # Extract the last local_seq_len columns from the global mask,
+                    # i.e. attention to the last window_size tokens.
+                    attention_mask_local = attention_mask[..., -local_seq_len:]
+                else:
+                    # If the global mask is shorter than local_seq_len, fall back to a plain causal
+                    # mask (this can happen during prefill while the local cache is being built).
+                    # Note: two unsqueezes are needed to reach the 4-D mask shape.
+                    attention_mask_local = torch.triu(
+                        torch.ones((q_len, local_seq_len), device=device, dtype=dtype) * float("-inf"),
+                        diagonal=1,
+                    ).unsqueeze(0).unsqueeze(0).expand(bsz, 1, -1, -1)  # [batch, 1, q_len, local_seq_len]
+            else:
+                # No global mask provided: build a plain causal mask for the local window.
+                attention_mask_local = torch.triu(
+                    torch.ones((q_len, local_seq_len), device=device, dtype=dtype) * float("-inf"),
+                    diagonal=1,
+                ).unsqueeze(0).unsqueeze(0).expand(bsz, 1, -1, -1)  # [batch, 1, q_len, local_seq_len]
+
+        # global attn: attend to all tokens in the shared cache
+        attn_output_global, attn_weights_global = attention_interface(
+            self,
+            query_states,
+            key_states_share,
+            value_states_share,
+            attention_mask_global,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        # local attn: attend only to tokens in the local cache (window_size)
+        attn_output_local, attn_weights_local = attention_interface(
+            self,
+            query_states,
+            key_states_local,
+            value_states_local,
+            attention_mask_local,
+            dropout=0.0 if not self.training else self.attention_dropout,
+            scaling=self.scaling,
+            **kwargs,
+        )
+
+        # eager_attention_forward returns [batch, seq_len, num_heads, head_dim],
+        # but Flash Attention might return [batch, num_heads, seq_len, head_dim];
+        # we need [batch, num_heads, seq_len, head_dim] to match the gate shape.
+        q_len = query_states.shape[2]  # Query sequence length
+        num_heads = query_states.shape[1]
+
+        # Normalize attn_output_global to [batch, num_heads, q_len, head_dim]
+        if attn_output_global.dim() == 4:
+            # [batch, seq_len, num_heads, head_dim] (eager) vs [batch, num_heads, seq_len, head_dim] (flash)
+            if attn_output_global.shape[1] == q_len:
+                attn_output_global = attn_output_global.transpose(1, 2)
+            # Ensure sequence length matches query length (take the first q_len tokens)
+            if attn_output_global.shape[2] > q_len:
+                attn_output_global = attn_output_global[:, :, :q_len, :]
+            elif attn_output_global.shape[2] < q_len:
+                # This shouldn't happen; fail loudly if it does
+                raise ValueError(f"attn_output_global seq_len {attn_output_global.shape[2]} < q_len {q_len}")
+
+        # Normalize attn_output_local to [batch, num_heads, q_len, head_dim]
+        if attn_output_local.dim() == 4:
+            if attn_output_local.shape[1] == q_len:
+                attn_output_local = attn_output_local.transpose(1, 2)
+            if attn_output_local.shape[2] > q_len:
+                attn_output_local = attn_output_local[:, :, :q_len, :]
+            elif attn_output_local.shape[2] < q_len:
+                raise ValueError(f"attn_output_local seq_len {attn_output_local.shape[2]} < q_len {q_len}")
+
+        assert gate_proj is not None
+        gate = gate_proj(query_states)  # [batch, num_heads, seq_len, 1]
+        mixed_attn_output = attn_output_local * (1 - gate) + attn_output_global * gate
+
+        # Back to [batch, seq_len, num_heads * head_dim] before the output projection
+        # (the mixed output is [batch, num_heads, q_len, head_dim] at this point).
+        mixed_attn_output = mixed_attn_output.transpose(1, 2).reshape(*input_shape, -1).contiguous()
+        mixed_attn_output = self.o_proj(mixed_attn_output)
+        return mixed_attn_output, (attn_weights_global, attn_weights_local, attn_output_global, attn_output_local, gate)
+
+
707
+ @use_kernel_forward_from_hub("RMSNorm")
708
+ class IQuestLoopCoderRMSNorm(nn.Module):
709
+ def __init__(self, hidden_size, eps=1e-6):
710
  """
711
+ IQuestLoopCoderRMSNorm is equivalent to T5LayerNorm
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
712
  """
713
+ super().__init__()
714
+ self.weight = nn.Parameter(torch.ones(hidden_size))
715
+ self.variance_epsilon = eps
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
716
 
717
+ def forward(self, hidden_states):
718
+ input_dtype = hidden_states.dtype
719
+ hidden_states = hidden_states.to(torch.float32)
720
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
721
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
722
+ return self.weight * hidden_states.to(input_dtype)
723
 
724
+ def extra_repr(self):
725
+ return f"{tuple(self.weight.shape)}, eps={self.variance_epsilon}"
726
+
727
+
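
A quick numeric check of the RMSNorm above (eps set to 0 so the arithmetic is exact):

    import torch

    norm = IQuestLoopCoderRMSNorm(4, eps=0.0)
    x = torch.tensor([[2.0, 2.0, 2.0, 2.0]])
    # rms(x) = sqrt(mean(x^2)) = 2, so x / rms(x) is all ones (weight initializes to 1).
    assert torch.allclose(norm(x), torch.ones(1, 4))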
+class IQuestLoopCoderDecoderLayer(GradientCheckpointingLayer):
     def __init__(self, config: IQuestLoopCoderConfig, layer_idx: int):
         super().__init__()
         self.hidden_size = config.hidden_size
+
         self.self_attn = IQuestLoopCoderAttention(config=config, layer_idx=layer_idx)
+
         self.mlp = IQuestLoopCoderMLP(config)
         self.input_layernorm = IQuestLoopCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.post_attention_layernorm = IQuestLoopCoderRMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+        self.layer_idx = layer_idx
+
     def forward(
         self,
         hidden_states: torch.Tensor,
+        loop_idx: int = 0,
+        gate_proj: Optional[LoopGateProjection] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_value: Optional[Cache] = None,
         use_cache: Optional[bool] = False,
         cache_position: Optional[torch.LongTensor] = None,
+        position_embeddings: Optional[
+            tuple[torch.Tensor, torch.Tensor]
+        ] = None,  # necessary, but kept here for BC
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> tuple[torch.Tensor]:
         residual = hidden_states
         hidden_states = self.input_layernorm(hidden_states)
+        # Self Attention
+        hidden_states, _ = self.self_attn(
             hidden_states=hidden_states,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_value=past_key_value,
             use_cache=use_cache,
             cache_position=cache_position,
+            loop_idx=loop_idx,
+            position_embeddings=position_embeddings,
+            gate_proj=gate_proj if loop_idx > 0 else None,
             **kwargs,
         )
         hidden_states = residual + hidden_states

+        # Fully Connected
         residual = hidden_states
         hidden_states = self.post_attention_layernorm(hidden_states)
         hidden_states = self.mlp(hidden_states)
         hidden_states = residual + hidden_states
+        return hidden_states

 
+@auto_docstring
 class IQuestLoopCoderPreTrainedModel(PreTrainedModel):
+    config: IQuestLoopCoderConfig
     base_model_prefix = "model"
     supports_gradient_checkpointing = True
     _no_split_modules = ["IQuestLoopCoderDecoderLayer"]
     _skip_keys_device_placement = ["past_key_values"]
+    _supports_flash_attn = True
+    _supports_sdpa = True
+    _supports_flex_attn = True
+
+    _can_compile_fullgraph = True
+    _supports_attention_backend = True
+    _can_record_outputs = {
+        "hidden_states": IQuestLoopCoderDecoderLayer,
+        "attentions": IQuestLoopCoderAttention,
+    }

+    # Important for inference with `device_map` / low_cpu_mem_usage:
+    # Avoid initializing parameters that are not present in the checkpoint.
+    # Those should keep their constructor-time initialization (e.g. zeros for LoopGateProjection),
+    # instead of being materialized from meta/empty tensors which can contain NaNs.
+    def _init_weights(self, module: nn.Module) -> None:
+        return

+class IQuestLoopCoderRotaryEmbedding(nn.Module):
+    def __init__(self, config: IQuestLoopCoderConfig, device=None):
+        super().__init__()
+        # BC: "rope_type" was originally "type"
+        if hasattr(config, "rope_scaling") and isinstance(config.rope_scaling, dict):
+            self.rope_type = config.rope_scaling.get(
+                "rope_type", config.rope_scaling.get("type")
+            )
+        else:
+            self.rope_type = "default"
+        self.max_seq_len_cached = config.max_position_embeddings
+        self.original_max_seq_len = config.max_position_embeddings
+
+        self.config = config
+        self.rope_init_fn = ROPE_INIT_FUNCTIONS[self.rope_type]
+
+        inv_freq, self.attention_scaling = self.rope_init_fn(self.config, device)
+        self.register_buffer("inv_freq", inv_freq, persistent=False)
+        self.original_inv_freq = self.inv_freq
+
+    @torch.no_grad()
+    @dynamic_rope_update  # power user: used with advanced RoPE types (e.g. dynamic rope)
+    def forward(self, x, position_ids):
+        inv_freq_expanded = (
+            self.inv_freq[None, :, None]
+            .float()
+            .expand(position_ids.shape[0], -1, 1)
+            .to(x.device)
+        )
+        position_ids_expanded = position_ids[:, None, :].float()
+
+        device_type = (
+            x.device.type
+            if isinstance(x.device.type, str) and x.device.type != "mps"
+            else "cpu"
+        )
+        with torch.autocast(device_type=device_type, enabled=False):  # Force float32
+            freqs = (
+                inv_freq_expanded.float() @ position_ids_expanded.float()
+            ).transpose(1, 2)
+            emb = torch.cat((freqs, freqs), dim=-1)
+            cos = emb.cos() * self.attention_scaling
+            sin = emb.sin() * self.attention_scaling
+
+        return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
+
+@auto_docstring
 class IQuestLoopCoderModel(IQuestLoopCoderPreTrainedModel):
     def __init__(self, config: IQuestLoopCoderConfig):
         super().__init__(config)
         self.padding_idx = config.pad_token_id
         self.vocab_size = config.vocab_size
+
+        self.embed_tokens = nn.Embedding(
+            config.vocab_size, config.hidden_size, self.padding_idx
+        )
+        self.layers = nn.ModuleList(
+            [
+                IQuestLoopCoderDecoderLayer(config, layer_idx)
+                for layer_idx in range(config.num_hidden_layers)
+            ]
+        )
         self.norm = IQuestLoopCoderRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.rotary_emb = IQuestLoopCoderRotaryEmbedding(config=config)
+        self.gradient_checkpointing = False
+        self.loop_num = getattr(self.config, "loop_num", 2)
+        self.loop_window_size = getattr(self.config, "loop_window_size", 64)
+
         # Gate projections for Loop 2+ (one per layer)
         self.gate_projections = nn.ModuleList([
             LoopGateProjection(config.num_attention_heads, config.head_dim)
             for _ in range(config.num_hidden_layers)
         ])

+        # Initialize weights and apply final processing
+        self.post_init()

+    @check_model_inputs
+    @auto_docstring
     def forward(
         self,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> BaseModelOutputWithPast:
+        if (input_ids is None) ^ (inputs_embeds is not None):
+            raise ValueError(
+                "You must specify exactly one of input_ids or inputs_embeds"
+            )

         if inputs_embeds is None:
             inputs_embeds = self.embed_tokens(input_ids)

+        if use_cache is None:
+            use_cache = self.config.use_cache
+
         if use_cache:
+            if needs_iquestloopcoder_cache(past_key_values):
+                past_key_values = IQuestLoopCoderCache(self.loop_window_size, self.config.num_hidden_layers, self.loop_num)

         if cache_position is None:
+            past_seen_tokens = (
+                past_key_values.get_seq_length() if past_key_values is not None else 0
+            )
+            cache_position = torch.arange(
+                past_seen_tokens,
+                past_seen_tokens + inputs_embeds.shape[1],
+                device=inputs_embeds.device,
+            )
+
         if position_ids is None:
             position_ids = cache_position.unsqueeze(0)
+
+        # It may already have been prepared by e.g. `generate`
+        if not isinstance(causal_mask_mapping := attention_mask, dict):
+            # Prepare mask arguments
+            mask_kwargs = {
+                "config": self.config,
+                "input_embeds": inputs_embeds,
+                "attention_mask": attention_mask,
+                "cache_position": cache_position,
+                "past_key_values": past_key_values,
+                "position_ids": position_ids,
+            }
+            # Create the full causal mask for all layers
+            # (all layers use full_attention; there are no sliding-window layers).
+            full_attention_mask = create_causal_mask(**mask_kwargs)
+            causal_mask_mapping = {
+                "full_attention": full_attention_mask,
+            }

         hidden_states = inputs_embeds

+        # create position embeddings to be shared across the decoder layers
+        position_embeddings = self.rotary_emb(hidden_states, position_ids)
+        hidden_states_list = []
+
+        for loop_idx in range(self.loop_num):
+            # Each loop uses the full_attention mask:
+            # Loop 1 uses it directly; Loop 2+ builds its local mask internally
+            # (forward_loop2) and uses the full mask only for global attention.
+            loop_attention_mask = causal_mask_mapping["full_attention"]

+            for layer_idx, decoder_layer in enumerate(self.layers[: self.config.num_hidden_layers]):
+                hidden_states = decoder_layer(
+                    hidden_states,
+                    loop_idx,
+                    gate_proj=self.gate_projections[layer_idx] if loop_idx > 0 else None,
+                    attention_mask=loop_attention_mask,
                     position_ids=position_ids,
+                    past_key_value=past_key_values,
+                    use_cache=use_cache,
+                    cache_position=cache_position,
+                    position_embeddings=position_embeddings,
+                    **kwargs,
                 )
+            if loop_idx < self.loop_num - 1:
+                hidden_states_list.append(hidden_states)

         hidden_states = self.norm(hidden_states)
+        hidden_states_list.append(hidden_states)
+
+        return (
+            BaseModelOutputWithPast(
+                last_hidden_state=hidden_states,
+                past_key_values=past_key_values if use_cache else None,
+            ),
+            hidden_states_list,
+        )
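
Tying the pieces together, a hypothetical end-to-end sketch (assumes a default-constructible config; exact field names and defaults live in IQuestLoopCoderConfig, and the @check_model_inputs bookkeeping is ignored here):

    import torch

    config = IQuestLoopCoderConfig()
    model = IQuestLoopCoderModel(config).eval()
    input_ids = torch.randint(0, config.vocab_size, (1, 8))

    with torch.no_grad():
        outputs, hidden_states_list = model(input_ids=input_ids, use_cache=True)

    # One intermediate hidden state per early loop, plus the final normed one.
    assert len(hidden_states_list) == model.loop_num
    # The model builds its own IQuestLoopCoderCache when none is supplied.
    assert isinstance(outputs.past_key_values, IQuestLoopCoderCache)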
 
+@auto_docstring
 class IQuestLoopCoderForCausalLM(IQuestLoopCoderPreTrainedModel, GenerationMixin):
     _tied_weights_keys = ["lm_head.weight"]
+    _tp_plan = {"lm_head": "colwise_rep"}
+    _pp_plan = {"lm_head": (["hidden_states"], ["logits"])}

     def __init__(self, config):
         super().__init__(config)
         self.model = IQuestLoopCoderModel(config)
         self.vocab_size = config.vocab_size
         self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
+
+        # Chunk size configuration
+        self.chunk_size = getattr(config, "chunk_size", 2)  # default chunk size is 2
+
         self.post_init()

     def get_input_embeddings(self):

     def get_decoder(self):
         return self.model

+    @can_return_tuple
+    @auto_docstring
     def forward(
         self,
+        input_ids: Optional[torch.LongTensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         position_ids: Optional[torch.LongTensor] = None,
         past_key_values: Optional[Cache] = None,
         inputs_embeds: Optional[torch.FloatTensor] = None,
         labels: Optional[torch.LongTensor] = None,
         use_cache: Optional[bool] = None,
         cache_position: Optional[torch.LongTensor] = None,
+        logits_to_keep: Union[int, torch.Tensor] = 0,
+        **kwargs: Unpack[TransformersKwargs],
+    ) -> CausalLMOutputWithPast:

+        outputs, hidden_states_list = self.model(
             input_ids=input_ids,
             attention_mask=attention_mask,
             position_ids=position_ids,
             past_key_values=past_key_values,
             inputs_embeds=inputs_embeds,
             use_cache=use_cache,
             cache_position=cache_position,
+            **kwargs,
+        )
+        slice_indices = (
+            slice(-logits_to_keep, None)
+            if isinstance(logits_to_keep, int)
+            else logits_to_keep
         )

+        def _select_token_positions(tensor: torch.Tensor) -> torch.Tensor:
+            if isinstance(slice_indices, slice):
+                return tensor[:, slice_indices, ...]
+            if isinstance(slice_indices, torch.Tensor):
+                return tensor.index_select(1, slice_indices.to(tensor.device))
+            raise TypeError(
+                f"Unsupported index type for logits_to_keep: {type(slice_indices)}"
+            )
+
+        stacked_exit_pdf = None
+
+        expected_logits_cache: Optional[torch.Tensor] = None
+
+        def compute_expected_logits() -> Optional[torch.Tensor]:
+            nonlocal expected_logits_cache
+            if expected_logits_cache is not None:
+                return expected_logits_cache
+            if stacked_exit_pdf is None or not hidden_states_list:
+                return None
+            token_exit_pdf = _select_token_positions(stacked_exit_pdf)
+            expected_logits = None
+            for step_idx, hidden in enumerate(hidden_states_list):
+                step_hidden = _select_token_positions(hidden)
+                step_logits = self.lm_head(step_hidden)
+                weight = (
+                    token_exit_pdf[..., step_idx].unsqueeze(-1).to(step_logits.dtype)
+                )
+                expected_logits = (
+                    step_logits * weight
+                    if expected_logits is None
+                    else expected_logits + step_logits * weight
+                )
+            expected_logits_cache = expected_logits
+            return expected_logits_cache
+
+        logits: Optional[torch.Tensor] = None
+        loss: Optional[torch.Tensor] = None
+
+        hidden_states = outputs.last_hidden_state
         logits = self.lm_head(hidden_states)
         logits = logits.float()

         if labels is not None:
             shift_logits = logits[..., :-1, :].contiguous()
             shift_labels = labels[..., 1:].contiguous()

             shift_labels = shift_labels.to(shift_logits.device)
             loss = loss_fct(shift_logits, shift_labels)

+        result = CausalLMOutputWithPast(
             loss=loss,
             logits=logits,
             past_key_values=outputs.past_key_values,
             attentions=outputs.attentions,
         )

+        return result
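
compute_expected_logits is inert in this revision (stacked_exit_pdf is never populated), but the math it encodes is a per-token mixture of per-loop logits under an exit distribution; a standalone numeric sketch:

    import torch

    # Logits from two loops for one token, and an exit pdf over the loops (toy numbers).
    step_logits = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0])]
    exit_pdf = torch.tensor([0.25, 0.75])

    expected = sum(w * logits for w, logits in zip(exit_pdf, step_logits))
    assert torch.allclose(expected, torch.tensor([2.5, 3.5]))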