ToastyPigeon committed on
Commit
549e5bd
Β·
verified Β·
1 Parent(s): 82a99b7

Upload modeling_loopstral.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. modeling_loopstral.py +56 -49
modeling_loopstral.py CHANGED
@@ -4,6 +4,7 @@
4
  # the file from the modular. If any change should be done, please apply the change to the
5
  # modular_mistral.py file directly. One of our CI enforces this.
6
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
 
7
  from typing import Callable, Optional, Union
8
 
9
  import torch
@@ -13,7 +14,7 @@ from torch.nn import CrossEntropyLoss
13
  #from transformers.modeling_utils import check_model_inputs
14
 
15
  from transformers.activations import ACT2FN
16
- from transformers.cache_utils import Cache, DynamicCache
17
  from transformers.generation import GenerationMixin
18
  from transformers.integrations import use_kernel_forward_from_hub
19
  from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
@@ -147,7 +148,7 @@ class MistralAttention(nn.Module):
147
  attention_mask: Optional[torch.Tensor],
148
  past_key_values: Optional[Cache] = None,
149
  cache_position: Optional[torch.LongTensor] = None,
150
- update_cache: Optional[bool] = True,
151
  **kwargs: Unpack[FlashAttentionKwargs],
152
  ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
153
  input_shape = hidden_states.shape[:-1]
@@ -163,20 +164,10 @@ class MistralAttention(nn.Module):
163
  if past_key_values is not None:
164
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
165
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
166
-
167
- # --- START DEBUGGING CODE ---
168
- # Add these two lines to see what the object is and what's inside it.
169
- #print(f"DEBUG: Type of cache object: {type(past_key_values)}")
170
- #print(f"DEBUG: Attributes of cache object: {dir(past_key_values)}")
171
- # --- END DEBUGGING CODE ---
172
-
173
- if update_cache:
174
- key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
175
- else:
176
- k_cache, v_cache = past_key_values[self.layer_idx]
177
- if k_cache is not None:
178
- key_states = torch.cat([k_cache, key_states], dim=2)
179
- value_states = torch.cat([v_cache, value_states], dim=2)
180
 
181
  attention_interface: Callable = eager_attention_forward
182
  if self.config._attn_implementation != "eager":
@@ -239,7 +230,7 @@ class MistralDecoderLayer(GradientCheckpointingLayer):
239
  use_cache: Optional[bool] = False,
240
  cache_position: Optional[torch.LongTensor] = None,
241
  position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
242
- update_cache: Optional[bool] = True,
243
  **kwargs: Unpack[TransformersKwargs],
244
  ) -> torch.Tensor:
245
  residual = hidden_states
@@ -253,7 +244,7 @@ class MistralDecoderLayer(GradientCheckpointingLayer):
253
  use_cache=use_cache,
254
  cache_position=cache_position,
255
  position_embeddings=position_embeddings,
256
- update_cache=update_cache,
257
  **kwargs,
258
  )
259
  hidden_states = residual + hidden_states
@@ -321,6 +312,29 @@ class MistralRotaryEmbedding(nn.Module):
321
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
322
 
323
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
324
  @auto_docstring
325
  class LoopstralModel(MistralPreTrainedModel):
326
  def __init__(self, config: LoopstralConfig):
@@ -336,6 +350,11 @@ class LoopstralModel(MistralPreTrainedModel):
336
  self.rotary_emb = MistralRotaryEmbedding(config=config)
337
  self.gradient_checkpointing = False
338
 
 
 
 
 
 
339
  # Initialize weights and apply final processing
340
  self.post_init()
341
 
@@ -358,8 +377,18 @@ class LoopstralModel(MistralPreTrainedModel):
358
  if inputs_embeds is None:
359
  inputs_embeds = self.embed_tokens(input_ids)
360
 
361
- if use_cache and past_key_values is None:
362
- past_key_values = DynamicCache(config=self.config)
 
 
 
 
 
 
 
 
 
 
363
 
364
  if cache_position is None:
365
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
@@ -382,34 +411,12 @@ class LoopstralModel(MistralPreTrainedModel):
382
 
383
  hidden_states = inputs_embeds
384
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
385
- #***
386
- #Create the loop sequence!
387
- #***
388
- l_seq = []
389
- #print(self.config.layer_sequence)
390
- for item in self.config.layer_sequence:
391
- if isinstance(item, int):
392
- # Single layer index: 5 -> [5]
393
- l_seq.append(item)
394
- elif isinstance(item, list):
395
- if len(item) == 2:
396
- # Range without repeat: [4, 20] -> range(4, 20)
397
- start, end = item
398
- l_seq += list(range(start, min(end, self.config.num_hidden_layers)))
399
- elif len(item) == 3:
400
- # Range with repeat: [4, 20, 2] -> range(4, 20) repeated 2 times
401
- start, end, repeats = item
402
- l_seq += list(range(start, min(end, self.config.num_hidden_layers))) * repeats
403
- else:
404
- raise ValueError(f"Invalid layer_sequence item: {item}. Expected int, [start, end], or [start, end, repeats]")
405
- else:
406
- raise ValueError(f"Invalid layer_sequence item type: {type(item)}. Expected int or list.")
407
- #print(f"DEBUG: Layer sequence {l_seq}")
408
-
409
- last_visit_map = {layer_idx: i for i, layer_idx in enumerate(l_seq)}
410
- for i, layer in enumerate(l_seq):
411
- should_update_cache = use_cache and (last_visit_map[layer] == i)
412
- decoder_layer = self.layers[layer]
413
  hidden_states = decoder_layer(
414
  hidden_states,
415
  attention_mask=causal_mask,
@@ -418,7 +425,7 @@ class LoopstralModel(MistralPreTrainedModel):
418
  use_cache=use_cache,
419
  cache_position=cache_position,
420
  position_embeddings=position_embeddings,
421
- update_cache=should_update_cache,
422
  **kwargs,
423
  )
424
  hidden_states = self.norm(hidden_states)
 
4
  # the file from the modular. If any change should be done, please apply the change to the
5
  # modular_mistral.py file directly. One of our CI enforces this.
6
  # 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
7
+ import copy
8
  from typing import Callable, Optional, Union
9
 
10
  import torch
 
14
  #from transformers.modeling_utils import check_model_inputs
15
 
16
  from transformers.activations import ACT2FN
17
+ from transformers.cache_utils import Cache, DynamicCache, DynamicLayer
18
  from transformers.generation import GenerationMixin
19
  from transformers.integrations import use_kernel_forward_from_hub
20
  from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
 
148
  attention_mask: Optional[torch.Tensor],
149
  past_key_values: Optional[Cache] = None,
150
  cache_position: Optional[torch.LongTensor] = None,
151
+ cache_slot_idx: Optional[int] = None,
152
  **kwargs: Unpack[FlashAttentionKwargs],
153
  ) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
154
  input_shape = hidden_states.shape[:-1]
 
164
  if past_key_values is not None:
165
  # sin and cos are specific to RoPE models; cache_position needed for the static cache
166
  cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
167
+ # Use cache_slot_idx (position in layer sequence) instead of layer_idx
168
+ # This allows each visit to a repeated layer to have its own cache slot
169
+ slot_idx = cache_slot_idx if cache_slot_idx is not None else self.layer_idx
170
+ key_states, value_states = past_key_values.update(key_states, value_states, slot_idx, cache_kwargs)
 
 
 
 
 
 
 
 
 
 
171
 
172
  attention_interface: Callable = eager_attention_forward
173
  if self.config._attn_implementation != "eager":
 
230
  use_cache: Optional[bool] = False,
231
  cache_position: Optional[torch.LongTensor] = None,
232
  position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
233
+ cache_slot_idx: Optional[int] = None,
234
  **kwargs: Unpack[TransformersKwargs],
235
  ) -> torch.Tensor:
236
  residual = hidden_states
 
244
  use_cache=use_cache,
245
  cache_position=cache_position,
246
  position_embeddings=position_embeddings,
247
+ cache_slot_idx=cache_slot_idx,
248
  **kwargs,
249
  )
250
  hidden_states = residual + hidden_states
 
312
  return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
313
 
314
 
315
+ def _expand_layer_sequence(layer_sequence, num_hidden_layers):
316
+ """Expand layer_sequence config into a flat list of layer indices."""
317
+ l_seq = []
318
+ for item in layer_sequence:
319
+ if isinstance(item, int):
320
+ # Single layer index: 5 -> [5]
321
+ l_seq.append(item)
322
+ elif isinstance(item, list):
323
+ if len(item) == 2:
324
+ # Range without repeat: [4, 20] -> range(4, 20)
325
+ start, end = item
326
+ l_seq += list(range(start, min(end, num_hidden_layers)))
327
+ elif len(item) == 3:
328
+ # Range with repeat: [4, 20, 2] -> range(4, 20) repeated 2 times
329
+ start, end, repeats = item
330
+ l_seq += list(range(start, min(end, num_hidden_layers))) * repeats
331
+ else:
332
+ raise ValueError(f"Invalid layer_sequence item: {item}. Expected int, [start, end], or [start, end, repeats]")
333
+ else:
334
+ raise ValueError(f"Invalid layer_sequence item type: {type(item)}. Expected int or list.")
335
+ return l_seq
336
+
337
+
338
  @auto_docstring
339
  class LoopstralModel(MistralPreTrainedModel):
340
  def __init__(self, config: LoopstralConfig):
 
350
  self.rotary_emb = MistralRotaryEmbedding(config=config)
351
  self.gradient_checkpointing = False
352
 
353
+ # Pre-compute the expanded layer sequence for the looping mechanism
354
+ self._layer_sequence = _expand_layer_sequence(config.layer_sequence, config.num_hidden_layers)
355
+ # Number of cache slots needed (one per position in layer sequence)
356
+ self._num_cache_slots = len(self._layer_sequence)
357
+
358
  # Initialize weights and apply final processing
359
  self.post_init()
360
 
 
377
  if inputs_embeds is None:
378
  inputs_embeds = self.embed_tokens(input_ids)
379
 
380
+ if use_cache:
381
+ if past_key_values is None:
382
+ # Create cache with enough slots for the full layer sequence
383
+ # (more than num_hidden_layers if layers are repeated)
384
+ cache_config = copy.copy(self.config)
385
+ cache_config.num_hidden_layers = self._num_cache_slots
386
+ past_key_values = DynamicCache(config=cache_config)
387
+ elif isinstance(past_key_values, DynamicCache) and len(past_key_values.layers) < self._num_cache_slots:
388
+ # Cache was created externally (e.g., by generate()) with fewer slots
389
+ # Extend it to have enough slots for our layer sequence
390
+ while len(past_key_values.layers) < self._num_cache_slots:
391
+ past_key_values.layers.append(DynamicLayer())
392
 
393
  if cache_position is None:
394
  past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
 
411
 
412
  hidden_states = inputs_embeds
413
  position_embeddings = self.rotary_emb(hidden_states, position_ids)
414
+
415
+ # Execute layers in the configured sequence
416
+ # Each position in the sequence gets its own cache slot, allowing
417
+ # repeated layers to maintain separate KV caches for each visit
418
+ for cache_slot_idx, layer_idx in enumerate(self._layer_sequence):
419
+ decoder_layer = self.layers[layer_idx]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
420
  hidden_states = decoder_layer(
421
  hidden_states,
422
  attention_mask=causal_mask,
 
425
  use_cache=use_cache,
426
  cache_position=cache_position,
427
  position_embeddings=position_embeddings,
428
+ cache_slot_idx=cache_slot_idx,
429
  **kwargs,
430
  )
431
  hidden_states = self.norm(hidden_states)