"""HuggingFace causal-LM wrapper for SHRAM.
ShramForCausalLM is the HuggingFace-facing language-model boundary for SHRAM.
It owns token embedding lookup, LM-head projection, wrapper-level next-token
cross-entropy loss, config-controlled tied embeddings, and generation/cache
orchestration at the wrapper boundary.
The backbone remains a pure transformer stack. ShramModel accepts pre-embedded
hidden states together with current position IDs, a current active mask, and an
optional ShramCache. It has no knowledge of token IDs, vocabulary projection,
or causal-LM loss.
HuggingFace generation reaches this wrapper with two different tensor
conventions:
- ``position_ids`` is a current-step tensor. GenerationMixin updates the total
sequence state between steps, then slices position-bearing tensors back down
before calling ``forward()``.
- ``attention_mask`` is a full 2D mask over the total sequence so far. This
wrapper slices its recent chunk to produce the current semantic liveness mask
expected by the backbone.
Generation-created caches are handled in ``_prepare_cache_for_generation``.
That hook ensures HuggingFace generation uses ShramCache rather than a generic
dynamic cache. The direct ``forward()`` path does not silently create caches;
when ``use_cache=True`` it expects a truthful ShramCache to have been supplied.
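For illustration only, the shapes this wrapper sees at ``forward()`` during a
single cached decode step might look like the following (batch size and lengths
are made up)::

    # After a 5-token prompt, generating token 6 with a live ShramCache:
    # input_ids:      (batch, 1)  -> only the newly sampled token
    # attention_mask: (batch, 6)  -> full mask over the total sequence so far
    # position_ids:   (batch, 1)  -> current-step positions, e.g. [[5]]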
"""
from dataclasses import dataclass
from typing import Any
import torch
import torch.nn as nn
from transformers import GenerationMixin, PreTrainedModel
from transformers.cache_utils import Cache
from transformers.generation.configuration_utils import GenerationMode
from transformers.modeling_outputs import CausalLMOutputWithPast
from .__cache__shram_cache import ShramCache
from .configuration import ShramConfig
from .model import ShramModel
@dataclass
class ShramCausalLMOutput(CausalLMOutputWithPast):
"""SHRAM causal-LM wrapper output.
This subclasses HuggingFace's standard ``CausalLMOutputWithPast``.
Dataclass inheritance is sufficient here: all standard causal-LM fields and
ModelOutput behavior are inherited from the parent, and this subclass adds
only the SHRAM-specific wrapper outputs.
"""
ce_loss: torch.FloatTensor | None = None
load_balance_loss: torch.FloatTensor | None = None
max_vio: torch.FloatTensor | None = None
class ShramForCausalLM(PreTrainedModel, GenerationMixin):
"""HuggingFace-facing causal language model wrapper for SHRAM.
Owns token embeddings, LM-head projection, wrapper-level shifted CE loss,
tied embedding configuration, and generation/cache boundary behavior.
Delegates all transformer computation to ``ShramModel``.
Args:
config: SHRAM model configuration.
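    Example:
        A minimal untrained forward pass, assuming ``ShramConfig`` provides
        usable defaults (shapes are illustrative)::

            config = ShramConfig()
            model = ShramForCausalLM(config)
            input_ids = torch.randint(0, config.vocab_size, (2, 8))
            out = model(input_ids=input_ids, labels=input_ids, use_cache=False)
            # out.logits: (2, 8, vocab_size); out.loss combines the CE and
            # weighted load-balance terms.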
"""
config_class = ShramConfig
base_model_prefix = "model"
_no_split_modules = ["DecoderLayer"]
supports_gradient_checkpointing = True
def __init__(self, config: ShramConfig) -> None:
super().__init__(config)
self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size)
self.model = ShramModel(config)
self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
self._configure_tied_embeddings()
self.post_init()
def _configure_tied_embeddings(self) -> None:
"""Apply config-controlled tied embedding behavior on this instance."""
if self.config.tie_word_embeddings:
self.lm_head.weight = self.embed_tokens.weight
self._tied_weights_keys = {
"lm_head.weight": "embed_tokens.weight",
}
else:
self._tied_weights_keys = {}
def num_mosrah_parameters(self) -> int:
"""Return the total number of trainable parameters belonging to MoSRAH layers.
Aggregates across all decoder layers. Excludes sliding-window path parameters,
FFN parameters, norms, and embeddings. Use this for experimental plotting of
MoSRAH parameter count versus performance.
Returns:
Total count of trainable MoSRAH parameters.
"""
return self.model.num_mosrah_parameters()
def get_input_embeddings(self) -> nn.Embedding:
"""Return the token embedding matrix."""
return self.embed_tokens
def set_input_embeddings(self, value: nn.Embedding) -> None:
"""Replace the token embedding matrix."""
self.embed_tokens = value
self._configure_tied_embeddings()
def get_output_embeddings(self) -> nn.Linear:
"""Return the LM head."""
return self.lm_head
def set_output_embeddings(self, value: nn.Linear) -> None:
"""Replace the LM head."""
self.lm_head = value
self._configure_tied_embeddings()
def _build_shram_cache(
self,
batch_size: int,
device: torch.device,
) -> ShramCache:
"""Construct a fresh top-level SHRAM cache."""
return ShramCache(
num_hidden_layers=self.config.num_hidden_layers,
sliding_window=self.config.window_size,
num_local_heads=self.config.num_sliding_window_heads,
local_head_dim=self.config.head_dim,
num_mosrah_heads=self.config.num_mosrah_heads,
mosrah_head_dim=self.config.head_dim,
batch_size=batch_size,
device=device,
)
def _validate_generation_cache_request(
self,
generation_config: Any,
model_kwargs: dict[str, Any],
generation_mode: GenerationMode,
) -> None:
"""Validate SHRAM's generation-side cache policy."""
if generation_mode in {
GenerationMode.ASSISTED_GENERATION,
GenerationMode.CONTRASTIVE_SEARCH,
}:
raise NotImplementedError(
"ShramForCausalLM does not currently support assisted generation "
"or contrastive search because ShramCache does not support crop()."
)
user_defined_cache = model_kwargs.get("past_key_values")
if user_defined_cache is not None:
if generation_config.cache_implementation is not None:
raise ValueError(
"Passing both `cache_implementation` and `past_key_values` "
"is unsupported. Please use only one."
)
if isinstance(user_defined_cache, tuple):
raise ValueError(
"Passing a tuple of `past_key_values` is not supported. "
"Please use a `ShramCache` instance."
)
if not isinstance(user_defined_cache, ShramCache):
raise TypeError(
"ShramForCausalLM requires `past_key_values` to be a "
"`ShramCache` instance."
)
if (
user_defined_cache is None
and generation_config.use_cache
and generation_config.cache_implementation is not None
):
raise ValueError(
"ShramForCausalLM does not support `cache_implementation`. "
"Generation-created caches must be `ShramCache` objects."
)
def _prepare_cache_for_generation(
self,
generation_config: Any,
model_kwargs: dict[str, Any],
generation_mode: GenerationMode,
batch_size: int,
max_cache_length: int,
) -> None:
"""Ensure HuggingFace generation uses ShramCache.
This is the SHRAM-specific generation hook. The rest of the default
generation plumbing is kept intact as much as possible.
Args:
generation_config: Active generation configuration.
model_kwargs: Generation kwargs, updated in place.
generation_mode: HuggingFace generation mode.
batch_size: Effective generation batch size.
max_cache_length: Requested cache length. Accepted but unused here.
"""
self._validate_generation_cache_request(
generation_config=generation_config,
model_kwargs=model_kwargs,
generation_mode=generation_mode,
)
if model_kwargs.get("past_key_values") is not None:
return
if not generation_config.use_cache:
return
num_repeats = max(
generation_config.num_beams or 1,
generation_config.num_return_sequences or 1,
)
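        # Beam search and multi-sequence sampling expand the batch before the
        # first forward call, so the cache must be sized for the expanded batch.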
model_kwargs["past_key_values"] = self._build_shram_cache(
            batch_size=batch_size * num_repeats,
device=self.embed_tokens.weight.device,
)
def _reorder_cache(
self,
past_key_values: Cache,
beam_idx: torch.Tensor,
) -> Cache:
"""Reorder the cache in place for beam search."""
past_key_values.reorder_cache(beam_idx)
return past_key_values
def _validate_input_ids(self, input_ids: torch.Tensor) -> None:
"""Validate token IDs at the wrapper boundary."""
if input_ids.ndim != 2:
raise ValueError("input_ids must have shape (batch, seq_len).")
if input_ids.shape[1] == 0:
raise ValueError("input_ids sequence length must be nonzero.")
if input_ids.dtype != torch.long:
raise TypeError("input_ids must be an long int tensor.")
def _validate_attention_mask(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None,
) -> None:
"""Validate the full-sequence attention mask."""
if attention_mask is None:
return
if attention_mask.ndim != 2:
raise ValueError("attention_mask must have shape (batch, total_seq_len).")
if attention_mask.shape[0] != input_ids.shape[0]:
raise ValueError("attention_mask batch dimension must match input_ids.")
if attention_mask.shape[1] < input_ids.shape[1]:
raise ValueError(
"attention_mask must be at least as long as the current input_ids chunk."
)
def _validate_position_ids(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor | None,
) -> None:
"""Validate current-step position IDs."""
if position_ids is None:
return
if position_ids.ndim != 2:
raise ValueError("position_ids must have shape (batch, seq_len).")
if position_ids.shape != input_ids.shape:
raise ValueError(
"position_ids must match the current input_ids shape exactly."
)
        if position_ids.dtype != torch.long:
            raise TypeError("position_ids must be a long tensor.")
def _validate_labels(
self,
input_ids: torch.Tensor,
labels: torch.Tensor | None,
) -> None:
"""Validate label shape at the wrapper boundary."""
if labels is None:
return
if labels.ndim != 2:
raise ValueError("labels must have shape (batch, seq_len).")
if labels.shape != input_ids.shape:
raise ValueError("labels must have the same shape as input_ids.")
        if labels.dtype != torch.long:
            raise TypeError("labels must be a long tensor.")
def _validate_cache_inputs(
self,
use_cache: bool,
past_key_values: Cache | None,
) -> None:
"""Validate cache policy for direct wrapper calls."""
if use_cache:
if past_key_values is None:
raise ValueError(
"use_cache=True requires an explicit ShramCache. During "
"generate(), HuggingFace should supply this through "
"_prepare_cache_for_generation()."
)
if not isinstance(past_key_values, ShramCache):
raise TypeError(
"past_key_values must be a ShramCache when use_cache=True."
)
return
if past_key_values is not None:
raise ValueError("past_key_values was provided while use_cache=False.")
def _validate_position_sources(
self,
use_cache: bool,
attention_mask: torch.Tensor | None,
position_ids: torch.Tensor | None,
) -> None:
"""Validate that cached forward has a truthful source of positions."""
if use_cache and attention_mask is None and position_ids is None:
raise ValueError(
"Cached forward requires either position_ids or attention_mask."
)
def _validate_hf_boundary(
self,
output_attentions: bool | None,
return_dict: bool | None,
inputs_embeds: torch.Tensor | None,
cache_position: torch.Tensor | None,
extra_kwargs: dict[str, Any],
) -> None:
"""Validate unsupported HuggingFace-facing wrapper inputs."""
if output_attentions:
raise NotImplementedError(
"ShramForCausalLM does not expose output_attentions."
)
if return_dict is False:
raise ValueError(
"return_dict=False is not supported. "
"ShramForCausalLM always returns ShramCausalLMOutput."
)
if inputs_embeds is not None:
raise ValueError(
"inputs_embeds is not supported at the SHRAM wrapper boundary. "
"Pass input_ids instead."
)
if extra_kwargs:
unsupported = ", ".join(sorted(extra_kwargs))
raise TypeError(
f"Unsupported forward kwargs for ShramForCausalLM: {unsupported}"
)
def _standardize_full_attention_mask(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None,
) -> torch.BoolTensor:
"""Return a concrete full-sequence boolean attention mask."""
if attention_mask is None:
return torch.ones_like(input_ids, dtype=torch.bool)
return attention_mask.to(dtype=torch.bool)
def _resolve_current_position_ids(
self,
input_ids: torch.Tensor,
position_ids: torch.Tensor | None,
full_attention_mask: torch.BoolTensor,
) -> torch.LongTensor:
"""Resolve concrete current-step position IDs for the backbone."""
if position_ids is not None:
return position_ids.to(dtype=torch.long)
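        # Derive positions from liveness: cumulative count of live tokens minus
        # one, with masked (inactive) slots clamped back to position 0.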
full_position_ids = full_attention_mask.to(dtype=torch.long).cumsum(dim=-1) - 1
full_position_ids = full_position_ids.masked_fill(~full_attention_mask, 0)
current_length = input_ids.shape[1]
return full_position_ids[:, -current_length:]
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor | None = None,
position_ids: torch.Tensor | None = None,
past_key_values: Cache | None = None,
use_cache: bool | None = None,
output_hidden_states: bool | None = None,
labels: torch.Tensor | None = None,
return_dict: bool | None = None,
ce_weight: float = 1.0,
load_balance_weight: float = 0.01,
**kwargs: Any,
) -> ShramCausalLMOutput:
"""Run the SHRAM causal language model wrapper.
Args:
input_ids: Current token IDs of shape ``(batch, seq_len)``.
attention_mask: Optional full 2D mask of shape
``(batch, total_seq_len)``. The wrapper slices its recent chunk
to produce the current semantic liveness mask expected by the
backbone.
position_ids: Optional current-step position IDs of shape
``(batch, seq_len)``. In ordinary HuggingFace generation this is
already the current-step tensor when it reaches ``forward()``.
past_key_values: Optional SHRAM cache. Required when
``use_cache=True``.
use_cache: Whether to use and return a cache. Defaults to
``config.use_cache``.
output_hidden_states: Whether to return backbone hidden states.
Defaults to ``config.output_hidden_states``.
labels: Optional target token IDs of shape ``(batch, seq_len)``.
return_dict: Must be ``True`` or ``None``.
ce_weight: Weight applied to the cross-entropy loss when combining with
the load-balance loss. Default 1.0.
load_balance_weight: Weight applied to the load-balance auxiliary loss.
Default 0.01, matching the paper's recommendation.
**kwargs: Unsupported HuggingFace kwargs fail explicitly.
Returns:
``ShramCausalLMOutput`` with:
- ``logits`` of shape ``(batch, seq_len, vocab_size)``,
- ``loss`` = ``ce_weight * ce_loss + load_balance_weight * load_balance_loss``
when labels are provided (``None`` otherwise),
- ``ce_loss`` — raw unweighted cross-entropy loss for logging,
- ``past_key_values`` as the active ``ShramCache`` or ``None``,
- ``hidden_states`` when requested,
- ``load_balance_loss`` — raw unweighted load-balance loss from the backbone,
- detached ``max_vio`` from the backbone.
"""
use_cache = use_cache if use_cache is not None else self.config.use_cache
output_hidden_states = (
output_hidden_states
if output_hidden_states is not None
else self.config.output_hidden_states
)
inputs_embeds = kwargs.pop("inputs_embeds", None)
output_attentions = kwargs.pop("output_attentions", None)
cache_position = kwargs.pop("cache_position", None)
# ------------------------------------------------------------------
# Validation zone.
#
# The wrapper boundary is where HuggingFace-facing inputs are judged
# for truthfulness before any internal work begins. These checks are
# intentionally front-loaded so the core logic below can assume one
# coherent interpretation of the call rather than defensively checking
# shapes, cache policy, or unsupported HF knobs at the point of use.
# This keeps the main sequence readable while ensuring invalid states
# fail before they can silently contaminate backbone execution.
# ------------------------------------------------------------------
self._validate_input_ids(input_ids)
self._validate_attention_mask(input_ids, attention_mask)
self._validate_position_ids(input_ids, position_ids)
self._validate_labels(input_ids, labels)
self._validate_cache_inputs(use_cache, past_key_values)
self._validate_position_sources(use_cache, attention_mask, position_ids)
self._validate_hf_boundary(
output_attentions=output_attentions,
return_dict=return_dict,
inputs_embeds=inputs_embeds,
cache_position=cache_position,
extra_kwargs=kwargs,
)
# ------------------------------------------------------------------
# Standardization zone.
#
# HuggingFace and SHRAM use different boundary conventions: generation
# carries a full-sequence 2D attention mask, while the SHRAM backbone
# wants a current-step active mask and concrete current position IDs.
# This zone collapses those wrapper-facing conventions into one valid
# backbone-facing state. After this point the core no longer reasons
# about optional or ambiguous input forms; it works only with concrete
# tensors whose semantics are already fixed.
# ------------------------------------------------------------------
full_attention_mask: torch.BoolTensor = self._standardize_full_attention_mask(
input_ids=input_ids,
attention_mask=attention_mask,
)
current_length: int = input_ids.shape[1]
current_active_mask: torch.BoolTensor = full_attention_mask[:, -current_length:]
current_position_ids: torch.LongTensor = self._resolve_current_position_ids(
input_ids=input_ids,
position_ids=position_ids,
full_attention_mask=full_attention_mask,
)
shram_cache: ShramCache | None = past_key_values if use_cache else None
# ------------------------------------------------------------------
# Core wrapper responsibilities.
#
# The wrapper's primary job is kept visible here: convert token IDs to
# embeddings, delegate transformer computation to ShramModel, project
# hidden states back to vocabulary logits, optionally compute the
# wrapper-level shifted next-token loss, and return the HuggingFace-
# facing output object. The backbone remains responsible only for
# transformer semantics; token/vocabulary/loss concerns stay here.
# ------------------------------------------------------------------
token_embeddings: torch.FloatTensor = self.embed_tokens(input_ids)
backbone_outputs = self.model(
inputs_embeds=token_embeddings,
position_ids=current_position_ids,
active_mask=current_active_mask,
cache=shram_cache,
output_hidden_states=output_hidden_states,
)
logits: torch.FloatTensor = self.lm_head(backbone_outputs["last_hidden_state"])
ce_loss: torch.FloatTensor | None = None
loss: torch.FloatTensor | None = None
if labels is not None:
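            # Standard next-token shift: the logit at position t is scored
            # against the token at position t + 1.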
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = labels[:, 1:].contiguous()
ce_loss = nn.functional.cross_entropy(
shift_logits.view(-1, self.config.vocab_size),
shift_labels.view(-1),
)
            loss = (
                ce_weight * ce_loss
                + load_balance_weight * backbone_outputs["load_balance_loss"]
            )
return ShramCausalLMOutput(
loss=loss,
ce_loss=ce_loss,
logits=logits,
past_key_values=backbone_outputs["past_key_values"],
hidden_states=backbone_outputs["hidden_states"],
load_balance_loss=backbone_outputs["load_balance_loss"],
max_vio=backbone_outputs["max_vio"],
)