# Model Adapters for True Early Exit
# Abstract interface to stop layer computation early across architectures
from abc import ABC, abstractmethod
from typing import Tuple, Optional
import torch
import torch.nn as nn
from torch import Tensor


class ModelAdapter(ABC):
    """Abstract interface for model internals to enable true early exit."""

    @abstractmethod
    def get_embed_tokens(self, input_ids: Tensor) -> Tensor:
        """Get token embeddings."""
        ...

    @abstractmethod
    def get_layers(self) -> nn.ModuleList:
        """Get list of decoder layers."""
        ...

    @abstractmethod
    def get_num_layers(self) -> int:
        """Get total number of layers."""
        ...

    @abstractmethod
    def forward_layer(
        self,
        layer: nn.Module,
        hidden_states: Tensor,
        position_ids: Tensor,
        attention_mask: Optional[Tensor],
        past_key_values: Optional[Tuple],
        position_embeddings: Optional[Tuple],
        use_cache: bool = True,
        cache_position: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tuple]]:
        """Forward through a single layer, returning hidden states and optional KV cache."""
        ...

    @abstractmethod
    def apply_final_norm(self, hidden_states: Tensor) -> Tensor:
        """Apply final normalization before lm_head."""
        ...

    @abstractmethod
    def get_lm_head_output(self, hidden_states: Tensor) -> Tensor:
        """Get logits from lm_head."""
        ...

    @abstractmethod
    def get_position_embeddings(
        self, hidden_states: Tensor, position_ids: Tensor
    ) -> Optional[Tuple[Tensor, Tensor]]:
        """Get rotary position embeddings (cos, sin) if applicable."""
        ...
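

# Illustrative only, not part of the adapter contract: a minimal sketch of how
# the interface above composes into a true early-exit forward pass. The helper
# name `early_exit_forward`, the fixed `exit_layer`, and the cache-free single
# pass are assumptions for this example; passing `attention_mask=None` presumes
# the layer falls back to its own causal masking (the SDPA path in recent
# transformers versions).
def early_exit_forward(
    adapter: ModelAdapter, input_ids: Tensor, exit_layer: int
) -> Tensor:
    """Run only the first `exit_layer` layers, then final norm + lm_head."""
    hidden_states = adapter.get_embed_tokens(input_ids)
    position_ids = torch.arange(
        input_ids.shape[-1], device=input_ids.device
    ).unsqueeze(0)
    position_embeddings = adapter.get_position_embeddings(hidden_states, position_ids)
    for layer in adapter.get_layers()[:exit_layer]:
        hidden_states, _ = adapter.forward_layer(
            layer,
            hidden_states,
            position_ids,
            attention_mask=None,    # rely on the layer's internal causal masking
            past_key_values=None,   # no KV cache in this single-pass sketch
            position_embeddings=position_embeddings,
            use_cache=False,
        )
    # Layers past exit_layer are never executed - that is what makes the exit "true".
    hidden_states = adapter.apply_final_norm(hidden_states)
    return adapter.get_lm_head_output(hidden_states)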


class LlamaStyleAdapter(ModelAdapter):
    """
    Adapter for Llama-style architectures.

    Works for: Llama, Llama2, Llama3, Qwen, Qwen2, Qwen3, Mistral, Gemma.

    These models share the same internal structure:
    - model.model.embed_tokens
    - model.model.layers (ModuleList of decoder layers)
    - model.model.norm (final RMSNorm)
    - model.lm_head
    - model.model.rotary_emb (RoPE embeddings)
    """

    def __init__(self, model):
        self.model = model
        self._base = model.model
        self._layers = self._base.layers
        self._embed = self._base.embed_tokens
        self._norm = self._base.norm
        self._lm_head = model.lm_head
        self._rotary = getattr(self._base, "rotary_emb", None)
        self._num_layers = len(self._layers)

    def get_embed_tokens(self, input_ids: Tensor) -> Tensor:
        return self._embed(input_ids)

    def get_layers(self) -> nn.ModuleList:
        return self._layers

    def get_num_layers(self) -> int:
        return self._num_layers

    def forward_layer(
        self,
        layer: nn.Module,
        hidden_states: Tensor,
        position_ids: Tensor,
        attention_mask: Optional[Tensor],
        past_key_values: Optional[Tuple],
        position_embeddings: Optional[Tuple],
        use_cache: bool = True,
        cache_position: Optional[Tensor] = None,
    ) -> Tuple[Tensor, Optional[Tuple]]:
        """Forward through a decoder layer."""
        layer_outputs = layer(
            hidden_states,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            position_embeddings=position_embeddings,
            cache_position=cache_position,
        )
        hidden_states = layer_outputs[0]
        # Newer transformers versions update the Cache object in place and may
        # return only hidden states, so the second element is optional.
        new_kv = layer_outputs[1] if len(layer_outputs) > 1 else None
        return hidden_states, new_kv

    def apply_final_norm(self, hidden_states: Tensor) -> Tensor:
        return self._norm(hidden_states)

    def get_lm_head_output(self, hidden_states: Tensor) -> Tensor:
        return self._lm_head(hidden_states)

    def get_position_embeddings(
        self, hidden_states: Tensor, position_ids: Tensor
    ) -> Optional[Tuple[Tensor, Tensor]]:
        if self._rotary is not None:
            cos, sin = self._rotary(hidden_states, position_ids)
            # Return as-is - the model's apply_rotary_pos_emb handles unsqueezing
            return (cos, sin)
        return None


def get_adapter(model) -> ModelAdapter:
    """
    Factory function to get the appropriate adapter for a model.

    Currently supports Llama-style models (Llama, Qwen, Mistral, Gemma).
    """
    # Check for Llama-style architecture
    if hasattr(model, "model") and hasattr(model.model, "layers"):
        return LlamaStyleAdapter(model)
    # GPT-2 style (transformer.h)
    if hasattr(model, "transformer") and hasattr(model.transformer, "h"):
        raise NotImplementedError("GPT-2 style models not yet supported")
    raise ValueError(f"Unsupported model architecture: {type(model)}")
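

# Usage sketch (the checkpoint name is a placeholder; any Llama-style causal LM
# should work the same way):
#
#   from transformers import AutoModelForCausalLM, AutoTokenizer
#
#   tok = AutoTokenizer.from_pretrained("meta-llama/Llama-3.2-1B")
#   model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-3.2-1B")
#   adapter = get_adapter(model)
#
#   input_ids = tok("True early exit skips", return_tensors="pt").input_ids
#   with torch.no_grad():
#       logits = early_exit_forward(
#           adapter, input_ids, exit_layer=adapter.get_num_layers() // 2
#       )
#   next_token_id = int(logits[0, -1].argmax())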