| """HuggingFace causal-LM wrapper for SHRAM. |
| |
| ShramForCausalLM is the HuggingFace-facing language-model boundary for SHRAM. |
| It owns token embedding lookup, LM-head projection, wrapper-level next-token |
| cross-entropy loss, config-controlled tied embeddings, and generation/cache |
| orchestration at the wrapper boundary. |
| |
| The backbone remains a pure transformer stack. ShramModel accepts pre-embedded |
| hidden states together with current position IDs, a current active mask, and an |
| optional ShramCache. It has no knowledge of token IDs, vocabulary projection, |
| or causal-LM loss. |
| |
| HuggingFace generation reaches this wrapper with two different tensor |
| conventions: |
| |
| - ``position_ids`` is a current-step tensor. GenerationMixin updates the total |
| sequence state between steps, then slices position-bearing tensors back down |
| before calling ``forward()``. |
| - ``attention_mask`` is a full 2D mask over the total sequence so far. This |
| wrapper slices its recent chunk to produce the current semantic liveness mask |
| expected by the backbone. |
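| |
| As a concrete illustration (shapes are made up): after a 5-token prompt and |
| one decoding step, ``attention_mask`` has shape ``(batch, 6)`` while |
| ``input_ids`` and ``position_ids`` have shape ``(batch, 1)``; the wrapper |
| takes ``attention_mask[:, -1:]`` as the current liveness mask. |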
| |
| Generation-created caches are handled in ``_prepare_cache_for_generation``. |
| That hook ensures HuggingFace generation uses ShramCache rather than a generic |
| dynamic cache. The direct ``forward()`` path does not silently create caches; |
| when ``use_cache=True`` it expects a truthful ShramCache to have been supplied. |
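| |
| A minimal generation sketch (illustrative only; the random prompt and |
| ``max_new_tokens`` value are placeholders, and ``ShramConfig()`` defaults |
| are assumed to be usable):: |
| |
| config = ShramConfig() |
| model = ShramForCausalLM(config).eval() |
| input_ids = torch.randint(0, config.vocab_size, (1, 8)) |
| attention_mask = torch.ones_like(input_ids) |
| out_ids = model.generate(input_ids, attention_mask=attention_mask, max_new_tokens=4) |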
| """ |
|
|
| from dataclasses import dataclass |
| from typing import Any |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import GenerationMixin, PreTrainedModel |
| from transformers.cache_utils import Cache |
| from transformers.generation.configuration_utils import GenerationMode |
| from transformers.modeling_outputs import CausalLMOutputWithPast |
|
|
| from .shram_cache import ShramCache |
| from .configuration import ShramConfig |
| from .model import ShramModel |
|
|
|
|
| @dataclass |
| class ShramCausalLMOutput(CausalLMOutputWithPast): |
| """SHRAM causal-LM wrapper output. |
| |
| This subclasses HuggingFace's standard ``CausalLMOutputWithPast``. |
| Dataclass inheritance is sufficient here: all standard causal-LM fields and |
| ModelOutput behavior are inherited from the parent, and this subclass adds |
| only the SHRAM-specific wrapper outputs. |
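| |
| Example (illustrative):: |
| |
| out = model(input_ids, labels=input_ids) |
| out.loss # weighted training objective |
| out.ce_loss # raw cross-entropy, for logging |
| out.load_balance_loss # raw auxiliary loss from the backbone |
| out.max_vio # detached backbone diagnostic |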
| """ |
|
|
| ce_loss: torch.FloatTensor | None = None |
| load_balance_loss: torch.FloatTensor | None = None |
| max_vio: torch.FloatTensor | None = None |
|
|
|
|
| class ShramForCausalLM(PreTrainedModel, GenerationMixin): |
| """HuggingFace-facing causal language model wrapper for SHRAM. |
| |
| Owns token embeddings, LM-head projection, wrapper-level shifted CE loss, |
| tied embedding configuration, and generation/cache boundary behavior. |
| Delegates all transformer computation to ``ShramModel``. |
| |
| Args: |
| config: SHRAM model configuration. |
| """ |
|
|
| config_class = ShramConfig |
| base_model_prefix = "model" |
| _no_split_modules = ["DecoderLayer"] |
| supports_gradient_checkpointing = True |
|
|
| def __init__(self, config: ShramConfig) -> None: |
| super().__init__(config) |
| self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size) |
| self.model = ShramModel(config) |
| self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) |
| self._configure_tied_embeddings() |
| self.post_init() |
|
|
| def _configure_tied_embeddings(self) -> None: |
| """Apply config-controlled tied embedding behavior on this instance.""" |
| if self.config.tie_word_embeddings: |
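| # Direct Parameter assignment makes lm_head and embed_tokens share a single |
| # weight tensor, so any update to either module updates the shared weights. |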
| self.lm_head.weight = self.embed_tokens.weight |
| self._tied_weights_keys = { |
| "lm_head.weight": "embed_tokens.weight", |
| } |
| else: |
| self._tied_weights_keys = {} |
|
|
| def num_mosrah_parameters(self) -> int: |
| """Return the total number of trainable parameters belonging to MoSRAH layers. |
| |
| Aggregates across all decoder layers. Excludes sliding-window path parameters, |
| FFN parameters, norms, and embeddings. Use this for experimental plotting of |
| MoSRAH parameter count versus performance. |
| |
| Returns: |
| Total count of trainable MoSRAH parameters. |
| """ |
| return self.model.num_mosrah_parameters() |
|
|
| def get_input_embeddings(self) -> nn.Embedding: |
| """Return the token embedding matrix.""" |
| return self.embed_tokens |
|
|
| def set_input_embeddings(self, value: nn.Embedding) -> None: |
| """Replace the token embedding matrix.""" |
| self.embed_tokens = value |
| self._configure_tied_embeddings() |
|
|
| def get_output_embeddings(self) -> nn.Linear: |
| """Return the LM head.""" |
| return self.lm_head |
|
|
| def set_output_embeddings(self, value: nn.Linear) -> None: |
| """Replace the LM head.""" |
| self.lm_head = value |
| self._configure_tied_embeddings() |
|
|
| def _build_shram_cache( |
| self, |
| batch_size: int, |
| device: torch.device, |
| ) -> ShramCache: |
| """Construct a fresh top-level SHRAM cache.""" |
| return ShramCache( |
| num_hidden_layers=self.config.num_hidden_layers, |
| sliding_window=self.config.window_size, |
| num_local_heads=self.config.num_sliding_window_heads, |
| local_head_dim=self.config.head_dim, |
| num_mosrah_heads=self.config.num_mosrah_heads, |
| mosrah_head_dim=self.config.head_dim, |
| batch_size=batch_size, |
| device=device, |
| ) |
|
|
| def _validate_generation_cache_request( |
| self, |
| generation_config: Any, |
| model_kwargs: dict[str, Any], |
| generation_mode: GenerationMode, |
| ) -> None: |
| """Validate SHRAM's generation-side cache policy.""" |
| if generation_mode in { |
| GenerationMode.ASSISTED_GENERATION, |
| GenerationMode.CONTRASTIVE_SEARCH, |
| }: |
| raise NotImplementedError( |
| "ShramForCausalLM does not currently support assisted generation " |
| "or contrastive search because ShramCache does not support crop()." |
| ) |
|
|
| user_defined_cache = model_kwargs.get("past_key_values") |
| if user_defined_cache is not None: |
| if generation_config.cache_implementation is not None: |
| raise ValueError( |
| "Passing both `cache_implementation` and `past_key_values` " |
| "is unsupported. Please use only one." |
| ) |
| if isinstance(user_defined_cache, tuple): |
| raise ValueError( |
| "Passing a tuple of `past_key_values` is not supported. " |
| "Please use a `ShramCache` instance." |
| ) |
| if not isinstance(user_defined_cache, ShramCache): |
| raise TypeError( |
| "ShramForCausalLM requires `past_key_values` to be a " |
| "`ShramCache` instance." |
| ) |
|
|
| if ( |
| user_defined_cache is None |
| and generation_config.use_cache |
| and generation_config.cache_implementation is not None |
| ): |
| raise ValueError( |
| "ShramForCausalLM does not support `cache_implementation`. " |
| "Generation-created caches must be `ShramCache` objects." |
| ) |
|
|
| def _prepare_cache_for_generation( |
| self, |
| generation_config: Any, |
| model_kwargs: dict[str, Any], |
| generation_mode: GenerationMode, |
| batch_size: int, |
| max_cache_length: int, |
| ) -> None: |
| """Ensure HuggingFace generation uses ShramCache. |
| |
| This is the SHRAM-specific generation hook. The rest of the default |
| generation plumbing is kept intact as much as possible. |
| |
| Args: |
| generation_config: Active generation configuration. |
| model_kwargs: Generation kwargs, updated in place. |
| generation_mode: HuggingFace generation mode. |
| batch_size: Effective generation batch size. |
| max_cache_length: Requested cache length. Accepted but unused here. |
| """ |
| self._validate_generation_cache_request( |
| generation_config=generation_config, |
| model_kwargs=model_kwargs, |
| generation_mode=generation_mode, |
| ) |
|
|
| if model_kwargs.get("past_key_values") is not None: |
| return |
|
|
| if not generation_config.use_cache: |
| return |
|
|
| num_repeats = max( |
| generation_config.num_beams or 1, |
| generation_config.num_return_sequences or 1, |
| ) |
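| # Generation may expand the effective batch (num_beams hypotheses in beam |
| # modes, num_return_sequences candidates when sampling), so the fresh cache |
| # is sized for the expanded batch rather than the raw input batch. |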
| model_kwargs["past_key_values"] = self._build_shram_cache( |
| batch_size=batch_size * num_repeats, |
| device=self.embed_tokens.weight.device, |
| ) |
|
|
| def _reorder_cache( |
| self, |
| past_key_values: Cache, |
| beam_idx: torch.Tensor, |
| ) -> Cache: |
| """Reorder the cache in place for beam search.""" |
| past_key_values.reorder_cache(beam_idx) |
| return past_key_values |
|
|
| def _validate_input_ids(self, input_ids: torch.Tensor) -> None: |
| """Validate token IDs at the wrapper boundary.""" |
| if input_ids.ndim != 2: |
| raise ValueError("input_ids must have shape (batch, seq_len).") |
| if input_ids.shape[1] == 0: |
| raise ValueError("input_ids sequence length must be nonzero.") |
| if input_ids.dtype != torch.long: |
| raise TypeError("input_ids must be a torch.long tensor.") |
|
|
| def _validate_attention_mask( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: torch.Tensor | None, |
| ) -> None: |
| """Validate the full-sequence attention mask.""" |
| if attention_mask is None: |
| return |
| if attention_mask.ndim != 2: |
| raise ValueError("attention_mask must have shape (batch, total_seq_len).") |
| if attention_mask.shape[0] != input_ids.shape[0]: |
| raise ValueError("attention_mask batch dimension must match input_ids.") |
| if attention_mask.shape[1] < input_ids.shape[1]: |
| raise ValueError( |
| "attention_mask must be at least as long as the current input_ids chunk." |
| ) |
|
|
| def _validate_position_ids( |
| self, |
| input_ids: torch.Tensor, |
| position_ids: torch.Tensor | None, |
| ) -> None: |
| """Validate current-step position IDs.""" |
| if position_ids is None: |
| return |
| if position_ids.ndim != 2: |
| raise ValueError("position_ids must have shape (batch, seq_len).") |
| if position_ids.shape != input_ids.shape: |
| raise ValueError( |
| "position_ids must match the current input_ids shape exactly." |
| ) |
| if position_ids.dtype != torch.long: |
| raise TypeError("position_ids must be a torch.long tensor.") |
|
|
| def _validate_labels( |
| self, |
| input_ids: torch.Tensor, |
| labels: torch.Tensor | None, |
| ) -> None: |
| """Validate label shape at the wrapper boundary.""" |
| if labels is None: |
| return |
| if labels.ndim != 2: |
| raise ValueError("labels must have shape (batch, seq_len).") |
| if labels.shape != input_ids.shape: |
| raise ValueError("labels must have the same shape as input_ids.") |
| if labels.dtype != torch.long: |
| raise TypeError("labels must be a torch.long tensor.") |
|
|
| def _validate_cache_inputs( |
| self, |
| use_cache: bool, |
| past_key_values: Cache | None, |
| ) -> None: |
| """Validate cache policy for direct wrapper calls.""" |
| if use_cache: |
| if past_key_values is None: |
| raise ValueError( |
| "use_cache=True requires an explicit ShramCache. During " |
| "generate(), HuggingFace should supply this through " |
| "_prepare_cache_for_generation()." |
| ) |
| if not isinstance(past_key_values, ShramCache): |
| raise TypeError( |
| "past_key_values must be a ShramCache when use_cache=True." |
| ) |
| return |
|
|
| if past_key_values is not None: |
| raise ValueError("past_key_values was provided while use_cache=False.") |
|
|
| def _validate_position_sources( |
| self, |
| use_cache: bool, |
| attention_mask: torch.Tensor | None, |
| position_ids: torch.Tensor | None, |
| ) -> None: |
| """Validate that cached forward has a truthful source of positions.""" |
| if use_cache and attention_mask is None and position_ids is None: |
| raise ValueError( |
| "Cached forward requires either position_ids or attention_mask." |
| ) |
|
|
| def _validate_hf_boundary( |
| self, |
| output_attentions: bool | None, |
| return_dict: bool | None, |
| inputs_embeds: torch.Tensor | None, |
| cache_position: torch.Tensor | None, |
| extra_kwargs: dict[str, Any], |
| ) -> None: |
| """Validate unsupported HuggingFace-facing wrapper inputs.""" |
| if output_attentions: |
| raise NotImplementedError( |
| "ShramForCausalLM does not expose output_attentions." |
| ) |
| if return_dict is False: |
| raise ValueError( |
| "return_dict=False is not supported. " |
| "ShramForCausalLM always returns ShramCausalLMOutput." |
| ) |
| if inputs_embeds is not None: |
| raise ValueError( |
| "inputs_embeds is not supported at the SHRAM wrapper boundary. " |
| "Pass input_ids instead." |
| ) |
| if extra_kwargs: |
| unsupported = ", ".join(sorted(extra_kwargs)) |
| raise TypeError( |
| f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}" |
| ) |
|
|
| def _standardize_full_attention_mask( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: torch.Tensor | None, |
| ) -> torch.BoolTensor: |
| """Return a concrete full-sequence boolean attention mask.""" |
| if attention_mask is None: |
| return torch.ones_like(input_ids, dtype=torch.bool) |
| return attention_mask.to(dtype=torch.bool) |
|
|
| def _resolve_current_position_ids( |
| self, |
| input_ids: torch.Tensor, |
| position_ids: torch.Tensor | None, |
| full_attention_mask: torch.BoolTensor, |
| ) -> torch.LongTensor: |
| """Resolve concrete current-step position IDs for the backbone.""" |
| if position_ids is not None: |
| return position_ids.to(dtype=torch.long) |
|
|
| full_position_ids = full_attention_mask.to(dtype=torch.long).cumsum(dim=-1) - 1 |
| full_position_ids = full_position_ids.masked_fill(~full_attention_mask, 0) |
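| # Illustrative: mask [[1, 1, 0, 1]] -> cumsum - 1 = [[0, 1, 1, 2]] -> zero |
| # padded slots = [[0, 1, 0, 2]]; only the trailing seq_len columns are kept. |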
| current_length = input_ids.shape[1] |
| return full_position_ids[:, -current_length:] |
|
|
| def forward( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: torch.Tensor | None = None, |
| position_ids: torch.Tensor | None = None, |
| past_key_values: Cache | None = None, |
| use_cache: bool | None = None, |
| output_hidden_states: bool | None = None, |
| labels: torch.Tensor | None = None, |
| return_dict: bool | None = None, |
| ce_weight: float = 1.0, |
| load_balance_weight: float = 0.01, |
| **kwargs: Any, |
| ) -> ShramCausalLMOutput: |
| """Run the SHRAM causal language model wrapper. |
| |
| Args: |
| input_ids: Current token IDs of shape ``(batch, seq_len)``. |
| attention_mask: Optional full 2D mask of shape |
| ``(batch, total_seq_len)``. The wrapper slices its recent chunk |
| to produce the current semantic liveness mask expected by the |
| backbone. |
| position_ids: Optional current-step position IDs of shape |
| ``(batch, seq_len)``. In ordinary HuggingFace generation this is |
| already the current-step tensor when it reaches ``forward()``. |
| past_key_values: Optional SHRAM cache. Required when |
| ``use_cache=True``. |
| use_cache: Whether to use and return a cache. Defaults to |
| ``config.use_cache``. |
| output_hidden_states: Whether to return backbone hidden states. |
| Defaults to ``config.output_hidden_states``. |
| labels: Optional target token IDs of shape ``(batch, seq_len)``. |
| return_dict: Must be ``True`` or ``None``. |
| ce_weight: Weight applied to the cross-entropy loss when combining with |
| the load-balance loss. Default 1.0. |
| load_balance_weight: Weight applied to the load-balance auxiliary loss. |
| Default 0.01, matching the paper's recommendation. |
| **kwargs: Unsupported HuggingFace kwargs fail explicitly. |
| |
| Returns: |
| ``ShramCausalLMOutput`` with: |
| - ``logits`` of shape ``(batch, seq_len, vocab_size)``, |
| - ``loss`` = ``ce_weight * ce_loss + load_balance_weight * load_balance_loss`` |
| when labels are provided (``None`` otherwise), |
| - ``ce_loss`` — raw unweighted cross-entropy loss for logging, |
| - ``past_key_values`` as the active ``ShramCache`` or ``None``, |
| - ``hidden_states`` when requested, |
| - ``load_balance_loss`` — raw unweighted load-balance loss from the backbone, |
| - detached ``max_vio`` from the backbone. |
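| |
| Example (illustrative training step; shapes and tokens are placeholders):: |
| |
| batch = torch.randint(0, model.config.vocab_size, (2, 16)) |
| out = model(batch, labels=batch, use_cache=False) |
| out.loss.backward() # ce_weight * ce_loss + load_balance_weight * load_balance_loss |
| ce_to_log = out.ce_loss.detach() |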
| """ |
| use_cache = use_cache if use_cache is not None else self.config.use_cache |
| output_hidden_states = ( |
| output_hidden_states |
| if output_hidden_states is not None |
| else self.config.output_hidden_states |
| ) |
|
|
| inputs_embeds = kwargs.pop("inputs_embeds", None) |
| output_attentions = kwargs.pop("output_attentions", None) |
| cache_position = kwargs.pop("cache_position", None) |
| |
| |
| self._validate_input_ids(input_ids) |
| self._validate_attention_mask(input_ids, attention_mask) |
| self._validate_position_ids(input_ids, position_ids) |
| self._validate_labels(input_ids, labels) |
| self._validate_cache_inputs(use_cache, past_key_values) |
| self._validate_position_sources(use_cache, attention_mask, position_ids) |
| self._validate_hf_boundary( |
| output_attentions=output_attentions, |
| return_dict=return_dict, |
| inputs_embeds=inputs_embeds, |
| cache_position=cache_position, |
| extra_kwargs=kwargs, |
| ) |
| |
| |
| full_attention_mask: torch.BoolTensor = self._standardize_full_attention_mask( |
| input_ids=input_ids, |
| attention_mask=attention_mask, |
| ) |
| current_length: int = input_ids.shape[1] |
| current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:] |
| current_position_ids: torch.LongTensor = self._resolve_current_position_ids( |
| input_ids=input_ids, |
| position_ids=position_ids, |
| full_attention_mask=full_attention_mask, |
| ) |
| shram_cache: ShramCache | None = past_key_values if use_cache else None |
| |
| |
| token_embeddings: torch.FloatTensor = self.embed_tokens(input_ids) |
| backbone_outputs = self.model( |
| inputs_embeds=token_embeddings, |
| position_ids=current_position_ids, |
| active_mask=current_active_mask, |
| cache=shram_cache, |
| output_hidden_states=output_hidden_states, |
| ) |
|
|
| logits: torch.FloatTensor = self.lm_head(backbone_outputs["last_hidden_state"]) |
|
|
| ce_loss: torch.FloatTensor | None = None |
| loss: torch.FloatTensor | None = None |
| if labels is not None: |
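| # Standard next-token shift: the logit at position t is scored against the |
| # label at position t + 1. F.cross_entropy ignores label value -100 by default. |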
| shift_logits = logits[:, :-1, :].contiguous() |
| shift_labels = labels[:, 1:].contiguous() |
| ce_loss = nn.functional.cross_entropy( |
| shift_logits.view(-1, self.config.vocab_size), |
| shift_labels.view(-1), |
| ) |
| loss = ce_weight * ce_loss + load_balance_weight * backbone_outputs["load_balance_loss"] |
|
|
| return ShramCausalLMOutput( |
| loss=loss, |
| ce_loss=ce_loss, |
| logits=logits, |
| past_key_values=backbone_outputs["past_key_values"], |
| hidden_states=backbone_outputs["hidden_states"], |
| load_balance_loss=backbone_outputs["load_balance_loss"], |
| max_vio=backbone_outputs["max_vio"], |
| ) |