Upload modeling_loopstral.py with huggingface_hub
Browse files- modeling_loopstral.py +56 -49
modeling_loopstral.py
CHANGED
|
@@ -4,6 +4,7 @@
|
|
| 4 |
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
# modular_mistral.py file directly. One of our CI enforces this.
|
| 6 |
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
|
|
|
| 7 |
from typing import Callable, Optional, Union
|
| 8 |
|
| 9 |
import torch
|
|
@@ -13,7 +14,7 @@ from torch.nn import CrossEntropyLoss
|
|
| 13 |
#from transformers.modeling_utils import check_model_inputs
|
| 14 |
|
| 15 |
from transformers.activations import ACT2FN
|
| 16 |
-
from transformers.cache_utils import Cache, DynamicCache
|
| 17 |
from transformers.generation import GenerationMixin
|
| 18 |
from transformers.integrations import use_kernel_forward_from_hub
|
| 19 |
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
|
@@ -147,7 +148,7 @@ class MistralAttention(nn.Module):
|
|
| 147 |
attention_mask: Optional[torch.Tensor],
|
| 148 |
past_key_values: Optional[Cache] = None,
|
| 149 |
cache_position: Optional[torch.LongTensor] = None,
|
| 150 |
-
|
| 151 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 152 |
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 153 |
input_shape = hidden_states.shape[:-1]
|
|
@@ -163,20 +164,10 @@ class MistralAttention(nn.Module):
|
|
| 163 |
if past_key_values is not None:
|
| 164 |
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 165 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 166 |
-
|
| 167 |
-
#
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
#print(f"DEBUG: Attributes of cache object: {dir(past_key_values)}")
|
| 171 |
-
# --- END DEBUGGING CODE ---
|
| 172 |
-
|
| 173 |
-
if update_cache:
|
| 174 |
-
key_states, value_states = past_key_values.update(key_states, value_states, self.layer_idx, cache_kwargs)
|
| 175 |
-
else:
|
| 176 |
-
k_cache, v_cache = past_key_values[self.layer_idx]
|
| 177 |
-
if k_cache is not None:
|
| 178 |
-
key_states = torch.cat([k_cache, key_states], dim=2)
|
| 179 |
-
value_states = torch.cat([v_cache, value_states], dim=2)
|
| 180 |
|
| 181 |
attention_interface: Callable = eager_attention_forward
|
| 182 |
if self.config._attn_implementation != "eager":
|
|
@@ -239,7 +230,7 @@ class MistralDecoderLayer(GradientCheckpointingLayer):
|
|
| 239 |
use_cache: Optional[bool] = False,
|
| 240 |
cache_position: Optional[torch.LongTensor] = None,
|
| 241 |
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 242 |
-
|
| 243 |
**kwargs: Unpack[TransformersKwargs],
|
| 244 |
) -> torch.Tensor:
|
| 245 |
residual = hidden_states
|
|
@@ -253,7 +244,7 @@ class MistralDecoderLayer(GradientCheckpointingLayer):
|
|
| 253 |
use_cache=use_cache,
|
| 254 |
cache_position=cache_position,
|
| 255 |
position_embeddings=position_embeddings,
|
| 256 |
-
|
| 257 |
**kwargs,
|
| 258 |
)
|
| 259 |
hidden_states = residual + hidden_states
|
|
@@ -321,6 +312,29 @@ class MistralRotaryEmbedding(nn.Module):
|
|
| 321 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 322 |
|
| 323 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 324 |
@auto_docstring
|
| 325 |
class LoopstralModel(MistralPreTrainedModel):
|
| 326 |
def __init__(self, config: LoopstralConfig):
|
|
@@ -336,6 +350,11 @@ class LoopstralModel(MistralPreTrainedModel):
|
|
| 336 |
self.rotary_emb = MistralRotaryEmbedding(config=config)
|
| 337 |
self.gradient_checkpointing = False
|
| 338 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 339 |
# Initialize weights and apply final processing
|
| 340 |
self.post_init()
|
| 341 |
|
|
@@ -358,8 +377,18 @@ class LoopstralModel(MistralPreTrainedModel):
|
|
| 358 |
if inputs_embeds is None:
|
| 359 |
inputs_embeds = self.embed_tokens(input_ids)
|
| 360 |
|
| 361 |
-
if use_cache
|
| 362 |
-
past_key_values
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 363 |
|
| 364 |
if cache_position is None:
|
| 365 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
|
@@ -382,34 +411,12 @@ class LoopstralModel(MistralPreTrainedModel):
|
|
| 382 |
|
| 383 |
hidden_states = inputs_embeds
|
| 384 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 385 |
-
|
| 386 |
-
#
|
| 387 |
-
#
|
| 388 |
-
|
| 389 |
-
|
| 390 |
-
|
| 391 |
-
if isinstance(item, int):
|
| 392 |
-
# Single layer index: 5 -> [5]
|
| 393 |
-
l_seq.append(item)
|
| 394 |
-
elif isinstance(item, list):
|
| 395 |
-
if len(item) == 2:
|
| 396 |
-
# Range without repeat: [4, 20] -> range(4, 20)
|
| 397 |
-
start, end = item
|
| 398 |
-
l_seq += list(range(start, min(end, self.config.num_hidden_layers)))
|
| 399 |
-
elif len(item) == 3:
|
| 400 |
-
# Range with repeat: [4, 20, 2] -> range(4, 20) repeated 2 times
|
| 401 |
-
start, end, repeats = item
|
| 402 |
-
l_seq += list(range(start, min(end, self.config.num_hidden_layers))) * repeats
|
| 403 |
-
else:
|
| 404 |
-
raise ValueError(f"Invalid layer_sequence item: {item}. Expected int, [start, end], or [start, end, repeats]")
|
| 405 |
-
else:
|
| 406 |
-
raise ValueError(f"Invalid layer_sequence item type: {type(item)}. Expected int or list.")
|
| 407 |
-
#print(f"DEBUG: Layer sequence {l_seq}")
|
| 408 |
-
|
| 409 |
-
last_visit_map = {layer_idx: i for i, layer_idx in enumerate(l_seq)}
|
| 410 |
-
for i, layer in enumerate(l_seq):
|
| 411 |
-
should_update_cache = use_cache and (last_visit_map[layer] == i)
|
| 412 |
-
decoder_layer = self.layers[layer]
|
| 413 |
hidden_states = decoder_layer(
|
| 414 |
hidden_states,
|
| 415 |
attention_mask=causal_mask,
|
|
@@ -418,7 +425,7 @@ class LoopstralModel(MistralPreTrainedModel):
|
|
| 418 |
use_cache=use_cache,
|
| 419 |
cache_position=cache_position,
|
| 420 |
position_embeddings=position_embeddings,
|
| 421 |
-
|
| 422 |
**kwargs,
|
| 423 |
)
|
| 424 |
hidden_states = self.norm(hidden_states)
|
|
|
|
| 4 |
# the file from the modular. If any change should be done, please apply the change to the
|
| 5 |
# modular_mistral.py file directly. One of our CI enforces this.
|
| 6 |
# 🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨🚨
|
| 7 |
+
import copy
|
| 8 |
from typing import Callable, Optional, Union
|
| 9 |
|
| 10 |
import torch
|
|
|
|
| 14 |
#from transformers.modeling_utils import check_model_inputs
|
| 15 |
|
| 16 |
from transformers.activations import ACT2FN
|
| 17 |
+
from transformers.cache_utils import Cache, DynamicCache, DynamicLayer
|
| 18 |
from transformers.generation import GenerationMixin
|
| 19 |
from transformers.integrations import use_kernel_forward_from_hub
|
| 20 |
from transformers.masking_utils import create_causal_mask, create_sliding_window_causal_mask
|
|
|
|
| 148 |
attention_mask: Optional[torch.Tensor],
|
| 149 |
past_key_values: Optional[Cache] = None,
|
| 150 |
cache_position: Optional[torch.LongTensor] = None,
|
| 151 |
+
cache_slot_idx: Optional[int] = None,
|
| 152 |
**kwargs: Unpack[FlashAttentionKwargs],
|
| 153 |
) -> tuple[torch.Tensor, Optional[torch.Tensor]]:
|
| 154 |
input_shape = hidden_states.shape[:-1]
|
|
|
|
| 164 |
if past_key_values is not None:
|
| 165 |
# sin and cos are specific to RoPE models; cache_position needed for the static cache
|
| 166 |
cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position}
|
| 167 |
+
# Use cache_slot_idx (position in layer sequence) instead of layer_idx
|
| 168 |
+
# This allows each visit to a repeated layer to have its own cache slot
|
| 169 |
+
slot_idx = cache_slot_idx if cache_slot_idx is not None else self.layer_idx
|
| 170 |
+
key_states, value_states = past_key_values.update(key_states, value_states, slot_idx, cache_kwargs)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
|
| 172 |
attention_interface: Callable = eager_attention_forward
|
| 173 |
if self.config._attn_implementation != "eager":
|
|
|
|
| 230 |
use_cache: Optional[bool] = False,
|
| 231 |
cache_position: Optional[torch.LongTensor] = None,
|
| 232 |
position_embeddings: Optional[tuple[torch.Tensor, torch.Tensor]] = None, # necessary, but kept here for BC
|
| 233 |
+
cache_slot_idx: Optional[int] = None,
|
| 234 |
**kwargs: Unpack[TransformersKwargs],
|
| 235 |
) -> torch.Tensor:
|
| 236 |
residual = hidden_states
|
|
|
|
| 244 |
use_cache=use_cache,
|
| 245 |
cache_position=cache_position,
|
| 246 |
position_embeddings=position_embeddings,
|
| 247 |
+
cache_slot_idx=cache_slot_idx,
|
| 248 |
**kwargs,
|
| 249 |
)
|
| 250 |
hidden_states = residual + hidden_states
|
|
|
|
| 312 |
return cos.to(dtype=x.dtype), sin.to(dtype=x.dtype)
|
| 313 |
|
| 314 |
|
| 315 |
+
def _expand_layer_sequence(layer_sequence, num_hidden_layers):
|
| 316 |
+
"""Expand layer_sequence config into a flat list of layer indices."""
|
| 317 |
+
l_seq = []
|
| 318 |
+
for item in layer_sequence:
|
| 319 |
+
if isinstance(item, int):
|
| 320 |
+
# Single layer index: 5 -> [5]
|
| 321 |
+
l_seq.append(item)
|
| 322 |
+
elif isinstance(item, list):
|
| 323 |
+
if len(item) == 2:
|
| 324 |
+
# Range without repeat: [4, 20] -> range(4, 20)
|
| 325 |
+
start, end = item
|
| 326 |
+
l_seq += list(range(start, min(end, num_hidden_layers)))
|
| 327 |
+
elif len(item) == 3:
|
| 328 |
+
# Range with repeat: [4, 20, 2] -> range(4, 20) repeated 2 times
|
| 329 |
+
start, end, repeats = item
|
| 330 |
+
l_seq += list(range(start, min(end, num_hidden_layers))) * repeats
|
| 331 |
+
else:
|
| 332 |
+
raise ValueError(f"Invalid layer_sequence item: {item}. Expected int, [start, end], or [start, end, repeats]")
|
| 333 |
+
else:
|
| 334 |
+
raise ValueError(f"Invalid layer_sequence item type: {type(item)}. Expected int or list.")
|
| 335 |
+
return l_seq
|
| 336 |
+
|
| 337 |
+
|
| 338 |
@auto_docstring
|
| 339 |
class LoopstralModel(MistralPreTrainedModel):
|
| 340 |
def __init__(self, config: LoopstralConfig):
|
|
|
|
| 350 |
self.rotary_emb = MistralRotaryEmbedding(config=config)
|
| 351 |
self.gradient_checkpointing = False
|
| 352 |
|
| 353 |
+
# Pre-compute the expanded layer sequence for the looping mechanism
|
| 354 |
+
self._layer_sequence = _expand_layer_sequence(config.layer_sequence, config.num_hidden_layers)
|
| 355 |
+
# Number of cache slots needed (one per position in layer sequence)
|
| 356 |
+
self._num_cache_slots = len(self._layer_sequence)
|
| 357 |
+
|
| 358 |
# Initialize weights and apply final processing
|
| 359 |
self.post_init()
|
| 360 |
|
|
|
|
| 377 |
if inputs_embeds is None:
|
| 378 |
inputs_embeds = self.embed_tokens(input_ids)
|
| 379 |
|
| 380 |
+
if use_cache:
|
| 381 |
+
if past_key_values is None:
|
| 382 |
+
# Create cache with enough slots for the full layer sequence
|
| 383 |
+
# (more than num_hidden_layers if layers are repeated)
|
| 384 |
+
cache_config = copy.copy(self.config)
|
| 385 |
+
cache_config.num_hidden_layers = self._num_cache_slots
|
| 386 |
+
past_key_values = DynamicCache(config=cache_config)
|
| 387 |
+
elif isinstance(past_key_values, DynamicCache) and len(past_key_values.layers) < self._num_cache_slots:
|
| 388 |
+
# Cache was created externally (e.g., by generate()) with fewer slots
|
| 389 |
+
# Extend it to have enough slots for our layer sequence
|
| 390 |
+
while len(past_key_values.layers) < self._num_cache_slots:
|
| 391 |
+
past_key_values.layers.append(DynamicLayer())
|
| 392 |
|
| 393 |
if cache_position is None:
|
| 394 |
past_seen_tokens = past_key_values.get_seq_length() if past_key_values is not None else 0
|
|
|
|
| 411 |
|
| 412 |
hidden_states = inputs_embeds
|
| 413 |
position_embeddings = self.rotary_emb(hidden_states, position_ids)
|
| 414 |
+
|
| 415 |
+
# Execute layers in the configured sequence
|
| 416 |
+
# Each position in the sequence gets its own cache slot, allowing
|
| 417 |
+
# repeated layers to maintain separate KV caches for each visit
|
| 418 |
+
for cache_slot_idx, layer_idx in enumerate(self._layer_sequence):
|
| 419 |
+
decoder_layer = self.layers[layer_idx]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 420 |
hidden_states = decoder_layer(
|
| 421 |
hidden_states,
|
| 422 |
attention_mask=causal_mask,
|
|
|
|
| 425 |
use_cache=use_cache,
|
| 426 |
cache_position=cache_position,
|
| 427 |
position_embeddings=position_embeddings,
|
| 428 |
+
cache_slot_idx=cache_slot_idx,
|
| 429 |
**kwargs,
|
| 430 |
)
|
| 431 |
hidden_states = self.norm(hidden_states)
|