| """Modeling file for HF compatibility and zero-shot experiments.""" |
|
|
| import torch |
| import math |
|
|
| from torch import Tensor |
| from torch.nn.attention.flex_attention import create_block_mask, BlockMask, flex_attention |
| from torch.nn.attention import bias as attn_bias |
| from dataclasses import dataclass |
| from typing import Union, Optional, Any |
|
|
|
|
| from .raven_config_minimal import RavenConfig |
| from transformers.cache_utils import Cache, DynamicCache, StaticCache |
|
|
| |
| from transformers import PreTrainedModel, GenerationMixin |
| from transformers.utils import ModelOutput |
| from transformers.generation.utils import GenerateDecoderOnlyOutput |
|
|
| import torch.nn.functional as F |
| from transformers import GenerationConfig |
| from transformers.models.olmo2.modeling_olmo2 import Olmo2RotaryEmbedding, apply_rotary_pos_emb |
|
|
| torch.backends.cuda.enable_math_sdp(False) |
|
|
|
|
| class RavenPreTrainedModel(PreTrainedModel): |
| config_class = RavenConfig |
| base_model_prefix = "model" |
| supports_gradient_checkpointing = True |
| _no_split_modules = ["SandwichBlock"] |
| _skip_keys_device_placement = ["past_key_values"] |
| _tied_weights_keys = ["lm_head.weight"] |
| _supports_flash_attn_2 = True |
| _supports_sdpa = True |
| _supports_cache_class = True |
| _supports_quantized_cache = False |
| _supports_static_cache = True |
| _tp_plan = {} |
|
|
| def _init_weights(self, module): |
| if not torch.rand((1,)).is_meta: |
| print("Random Initialization not implemented.") |
|
|
|
|
| @dataclass |
| class CausalLMOutputRecurrentLatents(ModelOutput): |
| loss: Optional[torch.Tensor] = None |
| log_ppl: Optional[torch.Tensor] = None |
| logits: Optional[torch.Tensor] = None |
| past_key_values: Optional[Cache] = None |
| latent_states: Optional[torch.Tensor] = None |
| hidden_states: Optional[torch.Tensor] = None |
| attention_maps: Optional[dict[int, torch.Tensor]] = None |
| stats: Optional[dict] = None |
|
|
|
|
| |
|
|
|
|
| class RMSNorm(torch.nn.Module): |
| """Saner dtype handling and slightly better for fusion""" |
|
|
| def __init__(self, dim: int, eps: float = 1e-6): |
| super().__init__() |
| self.eps = eps |
| self.weight = torch.nn.Parameter(torch.ones(dim)) |
|
|
| def _norm(self, x): |
| return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) |
|
|
| def forward(self, x): |
| with torch.autocast(enabled=False, device_type=x.device.type if x.device.type != "meta" else "cuda"): |
| return self._norm(x.float()).type_as(x) * self.weight |
|
|
| def reset_parameters(self) -> None: |
| torch.nn.init.ones_(self.weight) |
|
|
|
|
| class HuginnDynamicCache(DynamicCache): |
| def __init__(self, lookup_strategy: str = "full") -> None: |
| super().__init__() |
| self._seen_tokens = 0 |
| self.key_cache: dict[int, dict[int, torch.Tensor]] = {} |
| self.value_cache: dict[int, dict[int, torch.Tensor]] = {} |
| |
| |
| |
| |
| self.lookup_strategy = lookup_strategy |
|
|
| def update( |
| self, |
| key_states: torch.Tensor, |
| value_states: torch.Tensor, |
| step_idx_tensor: torch.Tensor, |
| lookup_strategy: Optional[str] = None, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| step_idx: int = int(step_idx_tensor) |
| lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy |
| if "compress-" in self.lookup_strategy and step_idx > 1: |
| if "compress-s" in self.lookup_strategy: |
| compression_stage = int(self.lookup_strategy.split("compress-")[1][1:]) |
| new_step_idx = (step_idx - 2) % compression_stage + 2 |
| elif "compress-anchor" in self.lookup_strategy: |
| if step_idx - 2 < 4 * 8: |
| new_step_idx = step_idx |
| else: |
| new_step_idx = 34 + (step_idx - 34) % 4 |
| |
| else: |
| compression_stage = int(self.lookup_strategy.split("compress-")[1][1:]) |
| new_step_idx = (step_idx - 2) // compression_stage + 2 |
| step_idx = new_step_idx |
| |
| if step_idx not in self.key_cache: |
| self.key_cache[step_idx] = {} |
| self.value_cache[step_idx] = {} |
| |
| if step_idx == 0: |
| self._seen_tokens += key_states.shape[-2] |
| |
| for idx, entry in enumerate(key_states.unbind(dim=-2)): |
| if "compress-" not in self.lookup_strategy: |
| assert step_idx < 0 or self._seen_tokens - key_states.shape[-2] + idx not in self.key_cache[step_idx] |
| self.key_cache[step_idx][self._seen_tokens - key_states.shape[-2] + idx] = entry |
| for idx, entry in enumerate(value_states.unbind(dim=-2)): |
| self.value_cache[step_idx][self._seen_tokens - value_states.shape[-2] + idx] = entry |
|
|
| |
| if len(self.key_cache[step_idx]) == self._seen_tokens or self.lookup_strategy == "full": |
| |
| return ( |
| torch.stack(list(self.key_cache[step_idx].values()), dim=-2), |
| torch.stack(list(self.value_cache[step_idx].values()), dim=-2), |
| ) |
| else: |
| if lookup_strategy.startswith("latest-m4"): |
| latest_keys = [] |
| latest_values = [] |
| for token_pos in range(self._seen_tokens): |
| |
| if step_idx >= 2: |
| |
| valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]] |
| max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4]) |
| else: |
| max_step = step_idx if token_pos in self.key_cache[step_idx] else 0 |
| latest_keys.append(self.key_cache[max_step][token_pos]) |
| latest_values.append(self.value_cache[max_step][token_pos]) |
| return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2) |
| elif lookup_strategy.startswith("available-m4"): |
| latest_keys = [] |
| latest_values = [] |
| for token_pos in range(self._seen_tokens): |
| if token_pos in self.key_cache[step_idx]: |
| step = step_idx |
| else: |
| |
| valid_steps = [s for s in range(step_idx + 1) if token_pos in self.key_cache[s]] |
| step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4]) |
| latest_keys.append(self.key_cache[step][token_pos]) |
| latest_values.append(self.value_cache[step][token_pos]) |
| return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2) |
| elif lookup_strategy.startswith("always-last-m4"): |
| latest_keys = [] |
| latest_values = [] |
| for token_pos in range(self._seen_tokens): |
| |
| if step_idx >= 2: |
| |
| valid_steps = [key_step for key_step in self.key_cache if token_pos in self.key_cache[key_step]] |
| max_step = max([s for s in valid_steps if s >= 2 and s % 4 == step_idx % 4]) |
| else: |
| max_step = step_idx if token_pos in self.key_cache[step_idx] else 0 |
| latest_keys.append(self.key_cache[max_step][token_pos]) |
| latest_values.append(self.value_cache[max_step][token_pos]) |
| return torch.stack(latest_keys, dim=-2), torch.stack(latest_values, dim=-2) |
| elif lookup_strategy.startswith("skip"): |
| existing_keys = [] |
| existing_values = [] |
| for token_pos in range(self._seen_tokens): |
| if token_pos in self.key_cache[step_idx]: |
| existing_keys.append(self.key_cache[step_idx][token_pos]) |
| existing_values.append(self.value_cache[step_idx][token_pos]) |
| return torch.stack(existing_keys, dim=-2), torch.stack(existing_values, dim=-2) |
| elif lookup_strategy.startswith("randomized"): |
| rand_keys = [] |
| rand_values = [] |
| for token_pos in range(self._seen_tokens): |
| if step_idx < 2: |
| max_step = step_idx if token_pos in self.key_cache[step_idx] else 0 |
| else: |
| curr_modulo = (step_idx - 2) % 4 + 2 |
| valid_steps = [ |
| s |
| for s in range(2, step_idx + 1) |
| if (s - 2) % 4 + 2 == curr_modulo and token_pos in self.key_cache[s] |
| ] |
| max_step = valid_steps[torch.randint(len(valid_steps), (1,))] |
| rand_keys.append(self.key_cache[max_step][token_pos]) |
| rand_values.append(self.value_cache[max_step][token_pos]) |
| return torch.stack(rand_keys, dim=-2), torch.stack(rand_values, dim=-2) |
| else: |
| raise ValueError(f"Unknown lookup strategy: {lookup_strategy}") |
|
|
| def reset(self) -> None: |
| """Reset the cache state.""" |
| self._seen_tokens = 0 |
| self.key_cache.clear() |
| self.value_cache.clear() |
|
|
| def clear_last_k_entries(self, k: int = 0): |
| """Partially clear cache.""" |
| assert self._seen_tokens >= k |
| self._seen_tokens = self._seen_tokens - k |
| |
| self.key_cache = { |
| step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens} |
| for step, cache in self.key_cache.items() |
| } |
| self.value_cache = { |
| step: {seq: seq_cache for seq, seq_cache in cache.items() if seq < self._seen_tokens} |
| for step, cache in self.value_cache.items() |
| } |
|
|
| def get_seq_length(self, step_idx: int = 0) -> int: |
| return self._seen_tokens |
|
|
| def get_memory_usage(self) -> float: |
| total_bytes = 0 |
| |
| for step_idx in self.key_cache: |
| |
| key_seq_cache = self.key_cache[step_idx] |
| for seq_idx in key_seq_cache: |
| key_tensor = key_seq_cache[seq_idx] |
| |
| total_bytes += key_tensor.nelement() * key_tensor.element_size() |
| return total_bytes * 2 / (1024 * 1024) |
|
|
|
|
| class HuginnStaticCache(Cache): |
| """Static Cache for the recurrent model""" |
|
|
| is_compileable = False |
|
|
| def __init__( |
| self, |
| max_length: int, |
| max_num_steps: int, |
| num_heads: int, |
| hidden_dim: int, |
| batch_size: int = 1, |
| lookup_strategy: str = "full", |
| device: Optional[Union[torch.device, str]] = None, |
| dtype: torch.dtype = torch.float32, |
| ) -> None: |
| super().__init__() |
| self._seen_tokens = 0 |
| self.max_length = max_length |
| self.lookup_strategy = lookup_strategy |
|
|
| |
| if "compress-" in lookup_strategy: |
| compression_stage = int(lookup_strategy.split("compress-")[1][1:]) |
| if "compress-s" in lookup_strategy: |
| |
| self.max_num_steps = 4 + compression_stage |
| else: |
| |
| self.max_num_steps = 4 + (max_num_steps - 4 + compression_stage - 1) // compression_stage |
| else: |
| self.max_num_steps = max_num_steps |
|
|
| |
| device = torch.device(device) if device is not None else None |
| cache_shape = (self.max_num_steps, batch_size, num_heads, max_length, hidden_dim) |
|
|
| self.key_cache = torch.zeros(cache_shape, dtype=dtype, device=device) |
| self.value_cache = torch.zeros(cache_shape, dtype=dtype, device=device) |
| self.valid_mask = torch.zeros((self.max_num_steps, max_length), dtype=torch.bool, device=device) |
| |
| torch._dynamo.mark_static_address(self.key_cache) |
| torch._dynamo.mark_static_address(self.value_cache) |
| torch._dynamo.mark_static_address(self.valid_mask) |
|
|
| def update( |
| self, |
| key_states: torch.Tensor, |
| value_states: torch.Tensor, |
| step_idx: torch.Tensor, |
| lookup_strategy: Optional[str] = None, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| if step_idx == 0: |
| self._seen_tokens += key_states.shape[-2] |
|
|
| |
| lookup_strategy = self.lookup_strategy if lookup_strategy is None else lookup_strategy |
| if "compress-" in lookup_strategy and step_idx > 1: |
| compression_stage = int(lookup_strategy.split("compress-")[1][1:]) |
| if "compress-s" in lookup_strategy: |
| step_idx = (step_idx - 2) % compression_stage + 2 |
| else: |
| step_idx = (step_idx - 2) // compression_stage + 2 |
|
|
| start_idx = self._seen_tokens - key_states.shape[-2] |
|
|
| indices = torch.arange(start_idx, start_idx + key_states.shape[-2], device=key_states.device) |
| self.key_cache[step_idx].index_copy_(2, indices, key_states) |
| self.value_cache[step_idx].index_copy_(2, indices, value_states) |
| self.valid_mask[step_idx, start_idx : start_idx + key_states.shape[-2]] = True |
|
|
| |
| if lookup_strategy == "full": |
| return ( |
| self.key_cache[step_idx, :, :, : self._seen_tokens], |
| self.value_cache[step_idx, :, :, : self._seen_tokens], |
| ) |
| elif lookup_strategy.startswith("latest-m4"): |
| if step_idx >= 2: |
| pattern_steps = torch.arange(2, step_idx.item() + 1, 4, device=self.valid_mask.device) |
| pattern_valid = self.valid_mask[pattern_steps] |
| max_valid_step = pattern_steps[pattern_valid.to(torch.long).argmax(dim=0)] |
| return ( |
| self.key_cache[max_valid_step, torch.arange(self._seen_tokens)], |
| self.value_cache[max_valid_step, torch.arange(self._seen_tokens)], |
| ) |
| return self.key_cache[step_idx, :, :, : self._seen_tokens], self.value_cache[ |
| step_idx, :, :, : self._seen_tokens |
| ] |
| elif lookup_strategy == "skip": |
| valid_mask = self.valid_mask[step_idx, : self._seen_tokens] |
| return ( |
| self.key_cache[step_idx, :, :, : self._seen_tokens][valid_mask], |
| self.value_cache[step_idx, :, :, : self._seen_tokens][valid_mask], |
| ) |
| elif lookup_strategy.startswith("randomized"): |
| if step_idx < 2: |
| max_step = step_idx |
| else: |
| curr_modulo = (step_idx - 2) % 4 + 2 |
| valid_steps = ( |
| torch.where( |
| (torch.arange(2, step_idx.item() + 1, device=self.valid_mask.device) - 2) % 4 + 2 == curr_modulo |
| )[0] |
| + 2 |
| ) |
| rand_idx = torch.randint(len(valid_steps), (1,), device=valid_steps.device) |
| max_step = valid_steps[rand_idx] |
| return self.key_cache[max_step, : self._seen_tokens], self.value_cache[max_step, : self._seen_tokens] |
| else: |
| raise ValueError(f"Unknown lookup strategy: {lookup_strategy}") |
|
|
| def reset(self) -> None: |
| self._seen_tokens = 0 |
| self.key_cache.zero_() |
| self.value_cache.zero_() |
| self.valid_mask.zero_() |
|
|
| def get_seq_length(self, step_idx: int = 0) -> int: |
| return self._seen_tokens |
|
|
| def get_memory_usage(self) -> float: |
| return (self.key_cache.nelement() + self.value_cache.nelement()) * self.key_cache.element_size() / (1024 * 1024) |
|
|
|
|
| ValidCache = HuginnDynamicCache | HuginnStaticCache |
|
|
|
|
| class CausalSelfAttention(torch.nn.Module): |
| def __init__(self, config: RavenConfig) -> None: |
| super().__init__() |
| self.config = config |
| self.n_head = config.num_attention_heads |
| self.n_kv_heads = config.num_key_value_heads |
| self.head_dim = getattr(config, "head_dim", config.n_embd // self.n_head) |
|
|
| shape = (self.n_head + 2 * self.n_kv_heads) * self.head_dim |
| self.chunks = [self.n_head * self.head_dim, self.n_kv_heads * self.head_dim, self.n_kv_heads * self.head_dim] |
|
|
| self.Wqkv = torch.nn.Linear(config.n_embd, shape, bias=False) |
| if config.qk_bias: |
| self.qk_bias = torch.nn.Parameter(torch.zeros(2, 1, self.n_head, self.head_dim)) |
| self.q_norm = RMSNorm(config.num_attention_heads * config.head_dim, eps=config.norm_eps) |
| self.k_norm = RMSNorm(config.num_key_value_heads * config.head_dim, eps=config.norm_eps) |
| self.proj = torch.nn.Linear(self.n_head * self.head_dim, config.n_embd, bias=False) |
|
|
| def forward( |
| self, |
| x: Tensor, |
| freqs_cis: Tensor, |
| block_idx: torch.Tensor, |
| mask: Optional[BlockMask] = None, |
| past_key_values: Optional[ValidCache] = None, |
| ) -> Tensor: |
| B, S, E = x.shape |
| q, k, v = self.Wqkv(x).split(self.chunks, dim=2) |
|
|
| q = self.q_norm(q) |
| k = self.k_norm(k) |
|
|
| q = q.view(B, S, self.n_head, self.head_dim) |
| k = k.view(B, S, self.n_kv_heads, self.head_dim) |
| v = v.view(B, S, self.n_kv_heads, self.head_dim) |
| |
| if self.config.qk_bias: |
| q_bias, k_bias = self.qk_bias.split(1, dim=0) |
| q, k = (q + q_bias).to(q.dtype), (k + k_bias).to(q.dtype) |
|
|
| q = q.transpose(1, 2) |
| k = k.transpose(1, 2) |
| v = v.transpose(1, 2) |
|
|
| |
| cos, sin = freqs_cis |
| q, k = apply_rotary_pos_emb(q, k, cos, sin) |
|
|
| if past_key_values is not None: |
| k, v = past_key_values.update(k, v, block_idx) |
|
|
| if mask is not None: |
| y: torch.Tensor = flex_attention(q, k, v, block_mask=mask) |
| else: |
| if q.shape[2] < k.shape[2]: |
| if q.shape[2] > 1: |
| bias = attn_bias.causal_lower_right(q.shape[2], k.shape[2]) |
| y = torch.nn.functional.scaled_dot_product_attention(q, k, v, bias, dropout_p=0.0, enable_gqa=True) |
| else: |
| y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=False, enable_gqa=True) |
| else: |
| y = torch.nn.functional.scaled_dot_product_attention(q, k, v, dropout_p=0.0, is_causal=True, enable_gqa=True) |
| y = y.transpose(1, 2).reshape(B, S, self.n_head * self.head_dim).contiguous() |
| return self.proj(y) |
|
|
|
|
| class GatedMLP(torch.nn.Module): |
| def __init__(self, config: RavenConfig, in_features: int = 0) -> None: |
| super().__init__() |
| in_features = config.n_embd if in_features == 0 else in_features |
| self.fc = torch.nn.Linear(in_features, config.intermediate_size * 2, bias=False) |
|
|
| self.proj = torch.nn.Linear(config.intermediate_size, config.n_embd, bias=False) |
| self.nonlin = torch.nn.SiLU() |
|
|
| def forward(self, x: Tensor) -> Tensor: |
| |
| x_fc_1, x_fc_2 = self.fc(x).chunk(2, dim=-1) |
| x = self.nonlin(x_fc_1) * x_fc_2 |
| return self.proj(x) |
|
|
|
|
| class SandwichBlock(torch.nn.Module): |
| expanded = False |
|
|
| def __init__(self, config: RavenConfig, layer_id: int) -> None: |
| super().__init__() |
| self.norm_1 = RMSNorm(config.n_embd, eps=config.norm_eps) |
| self.attn = CausalSelfAttention(config) |
| self.norm_2 = RMSNorm(config.n_embd, eps=config.norm_eps) |
| self.mlp = GatedMLP(config) |
| self.layer_id = layer_id |
|
|
| def forward( |
| self, |
| x: Tensor, |
| freqs_cis: Tensor, |
| step_idx: int, |
| mask: Optional[BlockMask] = None, |
| past_key_values: Optional[ValidCache] = None, |
| ) -> Tensor: |
| attn_out = self.norm_1(self.attn(x, freqs_cis, step_idx, mask, past_key_values)) |
| x = attn_out + x |
| x = self.norm_2(self.mlp(x)) + x |
| return x |
|
|
|
|
| class RavenForCausalLM(RavenPreTrainedModel, GenerationMixin): |
|
|
| def __init__( |
| self, |
| config: RavenConfig, |
| ) -> None: |
| super().__init__(config) |
| self.config = config |
|
|
| |
| prelude = torch.nn.ModuleList(SandwichBlock(config, layer_id=i) for i in range(config.n_layers_in_prelude)) |
| adapter = torch.nn.Linear(config.n_embd * 2, config.n_embd, bias=config.bias) |
| core_block = torch.nn.ModuleList( |
| SandwichBlock(config, layer_id=i + config.n_layers_in_prelude) |
| for i in range(config.n_layers_in_recurrent_block) |
| ) |
| o = config.n_layers_in_prelude + config.n_layers_in_recurrent_block * config.mean_recurrence |
| coda = torch.nn.ModuleList(SandwichBlock(config, layer_id=i + o) for i in range(config.n_layers_in_coda)) |
|
|
| self.transformer = torch.nn.ModuleDict( |
| dict( |
| wte=torch.nn.Embedding(config.padded_vocab_size, config.n_embd), |
| prelude=prelude, |
| adapter=adapter, |
| core_block=core_block, |
| coda=coda, |
| ln_f=RMSNorm(config.n_embd, eps=config.norm_eps), |
| ) |
| ) |
| self.emb_scale = config.init_values["embed_scale"] |
| |
| self.lm_head = torch.nn.Linear(config.n_embd, config.padded_vocab_size, bias=False) |
| if self.config.tie_embeddings: |
| self.tie_weights() |
| |
| self.rotary_emb = Olmo2RotaryEmbedding(config=config) |
|
|
| def get_input_embeddings(self): |
| return self.transformer.wte |
|
|
| def get_output_embeddings(self): |
| return self.lm_head |
|
|
|
|
| def compile_mask( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| past_key_values: Optional[ValidCache] = None, |
| pad_token_id=65509, |
| ) -> Optional[BlockMask]: |
| batch_size, seq_len = input_ids.shape[0], input_ids.shape[1] |
|
|
| |
| if attention_mask is None and (input_ids == pad_token_id).sum() == 0: |
| return None |
|
|
| if past_key_values is not None and seq_len == 1: |
| return None |
|
|
| |
| cache_len = past_key_values.get_seq_length() if past_key_values is not None else 0 |
| kv_length = cache_len + seq_len |
|
|
| if attention_mask is None: |
|
|
| def mask_mod(b, h, q_idx, kv_idx): |
| return q_idx >= kv_idx & (input_ids[b, kv_idx] != pad_token_id) |
| else: |
|
|
| def mask_mod(b, h, q_idx, kv_idx): |
| return (q_idx >= kv_idx) & (input_ids[b, kv_idx] != pad_token_id) & attention_mask[b, q_idx, kv_idx] |
|
|
| kv_length = past_key_values.get_seq_length() if past_key_values is not None else seq_len |
| if kv_length == 0: |
| kv_length = seq_len |
| block_mask = create_block_mask( |
| mask_mod, |
| B=batch_size, |
| H=None, |
| Q_LEN=seq_len, |
| KV_LEN=kv_length, |
| device=input_ids.device, |
| ) |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
| |
|
|
| |
| |
| |
|
|
| |
| |
| |
| |
| |
| |
|
|
| |
|
|
| |
| block_mask = create_block_mask( |
| mask_mod, |
| B=batch_size, |
| H=None, |
| Q_LEN=seq_len, |
| KV_LEN=kv_length, |
| device=input_ids.device, |
| ) |
|
|
| return block_mask |
|
|
| def forward( |
| self, |
| input_ids: torch.Tensor, |
| input_embeds: Optional[torch.Tensor] = None, |
| input_states: Optional[torch.Tensor] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.Tensor] = None, |
| labels: Optional[torch.Tensor] = None, |
| num_steps: Optional[torch.Tensor] = None, |
| past_key_values: Optional[ValidCache] = None, |
| output_details: dict = { |
| "return_logits": True, |
| "return_latents": True, |
| "return_head": False, |
| "return_stats": False, |
| }, |
| use_cache: bool = False, |
| cache_position: Optional[torch.Tensor] = None, |
| init_scale: float = 1.0, |
| **kwargs, |
| ) -> CausalLMOutputRecurrentLatents: |
| |
| if position_ids is None and cache_position is None: |
| position_ids = torch.arange(input_ids.shape[1], device=self.device).unsqueeze(0) |
| elif cache_position is not None: |
| position_ids = cache_position.unsqueeze(0) |
|
|
| if input_embeds is None: |
| input_embeds = self.transformer.wte(input_ids) |
|
|
| if self.emb_scale != 1: |
| input_embeds = input_embeds * self.emb_scale |
|
|
| if use_cache and past_key_values is None: |
| past_key_values = HuginnDynamicCache() |
|
|
| prepared_attn_mask = None |
| block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) |
|
|
| freqs_cis = self.rotary_emb(input_embeds, position_ids) |
|
|
| |
| for block in self.transformer.prelude: |
| block_idx += 1 |
| input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values) |
|
|
| |
| x, num_steps_no_grad, num_steps_with_grad, xk, block_idx = self.iterate_forward( |
| input_embeds, |
| input_states, |
| freqs_cis, |
| block_idx, |
| prepared_attn_mask, |
| past_key_values, |
| num_steps, |
| init_scale, |
| ) |
| latent_states = x.clone().detach() |
|
|
| |
| block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) |
| for block in self.transformer.coda: |
| block_idx -= 1 |
| x = block(x, freqs_cis, block_idx, prepared_attn_mask, past_key_values) |
| x = self.transformer.ln_f(x) |
|
|
| |
| if labels is not None: |
| logits = self.lm_head(x).float() |
| loss = torch.nn.functional.cross_entropy( |
| logits.view(-1, logits.shape[-1]), labels.view(-1), ignore_index=-100 |
| ) |
| log_ppl = loss.clone().detach().exp() |
| else: |
| logits = self.lm_head(x) |
| loss, log_ppl = torch.as_tensor(0.0), torch.as_tensor(0.0) |
|
|
| return CausalLMOutputRecurrentLatents( |
| loss=loss, |
| log_ppl=log_ppl, |
| logits=logits if output_details["return_logits"] else None, |
| past_key_values=past_key_values, |
| hidden_states=x if output_details["return_head"] else None, |
| latent_states=latent_states if output_details["return_latents"] else None, |
| stats=self.get_stats(logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad) |
| if output_details["return_stats"] |
| else None, |
| ) |
|
|
| @torch._dynamo.disable(recursive=False) |
| def iterate_forward( |
| self, |
| input_embeds: torch.Tensor, |
| input_states: torch.Tensor, |
| freqs_cis, |
| block_idx: torch.Tensor, |
| mask: Optional[BlockMask], |
| past_key_values: Optional[ValidCache] = None, |
| num_steps: Optional[torch.Tensor] = None, |
| init_scale: float = 1.0, |
| ): |
| x = xk = self.initialize_state(input_embeds, scale=init_scale) if input_states is None else input_states.clone() |
| if num_steps is None: |
| num_steps_no_grad, num_steps_with_grad = self.randomized_iteration_sampler() |
| elif hasattr(num_steps, "__len__") and len(num_steps) > 1: |
| num_steps_no_grad, num_steps_with_grad = num_steps |
| else: |
| num_steps_no_grad, num_steps_with_grad = num_steps, torch.tensor(0) if not x.is_meta else 0 |
|
|
| with torch.no_grad(): |
| |
| |
| |
| |
| for no_grad_step in range(num_steps_no_grad): |
| xk = x |
| x, block_idx = self.core_block_forward( |
| xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, no_grad_step |
| ) |
|
|
| for grad_step in range(num_steps_with_grad): |
| xk = x |
| x, block_idx = self.core_block_forward( |
| xk, input_embeds, freqs_cis, mask, past_key_values, block_idx, num_steps_no_grad + grad_step |
| ) |
| return x, num_steps_no_grad, num_steps_with_grad, xk.detach(), block_idx |
|
|
| def core_block_forward( |
| self, |
| x, |
| input_embeds, |
| freqs_cis, |
| mask: Optional[BlockMask], |
| past_key_values, |
| block_idx: torch.Tensor, |
| current_step: int | Tensor, |
| ): |
| x = self._maybe_inject_noise(x, current_step) |
| x = self.transformer.adapter(torch.cat([x, input_embeds.to(x.device)], dim=-1)) |
| for block in self.transformer.core_block: |
| block_idx += 1 |
| x = block(x, freqs_cis, block_idx, mask, past_key_values) |
| return x, block_idx |
|
|
| @torch.no_grad() |
| def iterate_one_step( |
| self, |
| input_embeds, |
| input_states, |
| position_ids: Optional[torch.Tensor] = None, |
| cache_position: Optional[torch.Tensor] = None, |
| block_idx: torch.Tensor = torch.tensor(0, dtype=torch.long), |
| attention_mask: Optional[BlockMask] = None, |
| past_key_values: Optional[ValidCache] = None, |
| current_step: int = 0, |
| ): |
| if position_ids is None and cache_position is None: |
| freqs_cis = self.freqs_cis[:, : input_embeds.shape[1]] |
| elif position_ids is not None: |
| freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze()) |
| elif cache_position is not None: |
| freqs_cis = self.freqs_cis[:, cache_position] |
| x, block_idx = self.core_block_forward( |
| input_states, |
| input_embeds, |
| freqs_cis, |
| attention_mask, |
| past_key_values, |
| block_idx, |
| current_step=current_step, |
| ) |
| return x, block_idx, current_step + 1 |
|
|
| def predict_from_latents( |
| self, |
| latents, |
| attention_mask: Optional[BlockMask] = None, |
| position_ids: Optional[torch.Tensor] = None, |
| cache_position: Optional[torch.Tensor] = None, |
| past_key_values: Optional[ValidCache] = None, |
| ): |
| if position_ids is None and cache_position is None: |
| freqs_cis = self.freqs_cis[:, : latents.shape[1]] |
| elif position_ids is not None: |
| freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze()) |
| elif cache_position is not None: |
| freqs_cis = self.freqs_cis[:, cache_position] |
| x = self.transformer.ln_f(latents) |
| |
| block_idx = torch.tensor(0, device=torch.device("cpu"), dtype=torch.long) |
| for block in self.transformer.coda: |
| block_idx -= 1 |
| x = block(x, freqs_cis, block_idx, attention_mask, past_key_values) |
| x = self.transformer.ln_f(x) |
|
|
| logits = self.lm_head(x).float() |
|
|
| return CausalLMOutputRecurrentLatents( |
| loss=torch.as_tensor(0.0), |
| log_ppl=torch.as_tensor(0.0), |
| logits=logits, |
| past_key_values=past_key_values, |
| latent_states=x, |
| ) |
|
|
| def embed_inputs( |
| self, |
| input_ids: torch.Tensor, |
| attention_mask: Optional[torch.Tensor] = None, |
| position_ids: Optional[torch.Tensor] = None, |
| past_key_values: Optional[ValidCache] = None, |
| use_cache: bool = False, |
| cache_position: Optional[torch.Tensor] = None, |
| **kwargs, |
| ) -> tuple[torch.Tensor, torch.Tensor]: |
| |
| if position_ids is None and cache_position is None: |
| freqs_cis = self.freqs_cis[:, : input_ids.shape[1]] |
| elif position_ids is not None: |
| freqs_cis = self.freqs_cis.index_select(1, position_ids.squeeze()) |
| elif cache_position is not None: |
| freqs_cis = self.freqs_cis[:, cache_position] |
|
|
| input_embeds = self.transformer.wte(input_ids) |
| prepared_attn_mask = self.compile_mask(input_ids, attention_mask) |
|
|
| if self.emb_scale != 1: |
| input_embeds = input_embeds * self.emb_scale |
|
|
| if use_cache and past_key_values is None: |
| past_key_values = HuginnDynamicCache() |
|
|
| block_idx = torch.tensor(-1, device=torch.device("cpu"), dtype=torch.long) |
| |
| for block in self.transformer.prelude: |
| block_idx += 1 |
| input_embeds = block(input_embeds, freqs_cis, block_idx, prepared_attn_mask, past_key_values) |
| return input_embeds, block_idx |
|
|
| @torch._dynamo.disable(recursive=False) |
| def randomized_iteration_sampler(self) -> tuple[torch.Tensor, torch.Tensor]: |
| """Outputs are long tensors so that they can be passed through compiled functions""" |
| t = max(self.config.mean_recurrence - self.config.mean_backprop_depth, 0) |
| s = self.config.mean_backprop_depth |
| if torch.rand((1,)).is_meta: |
| |
| |
| return t, s |
| if self.training: |
| sigma = 0.5 |
| mu = math.log(t + s) - (sigma**2 / 2) |
| rate = torch.zeros((1,)).log_normal_(mean=mu, std=sigma) |
| p = torch.poisson(torch.tensor([rate], dtype=torch.float)) + 1 |
| n = torch.clamp(p - s, min=0) |
| k = torch.as_tensor(torch.minimum(torch.as_tensor(s), p)) |
| else: |
| n, k = torch.as_tensor(self.config.mean_recurrence), torch.as_tensor(0) |
|
|
| return n.to(dtype=torch.long), k.to(dtype=torch.long) |
|
|
| def initialize_state(self, input_embeds, scale: float = 1.0): |
| x = torch.randn_like(input_embeds) |
| std = self.config.init_values["std"] * scale |
| if std > 0: |
| torch.nn.init.trunc_normal_(x, mean=0.0, std=std, a=-3 * std, b=3 * std) |
| if self.emb_scale != 1: |
| x = x * self.emb_scale |
| else: |
| x.zero_() |
| return x |
|
|
| def _maybe_inject_noise(self, x, current_step, renorm=False): |
| if self.config.test_time_noise > 0: |
| n = self.config.test_time_noise * self.config.init_values["std"] * self.emb_scale |
| if self.config.test_time_noise_type == "geom": |
| step1 = torch.as_tensor(current_step + 1, device=x.device) |
| x = x * (1 - n / step1) + torch.randn_like(x) * n / step1 |
| elif self.config.test_time_noise_type == "sqrt": |
| step1sqrt = torch.as_tensor(current_step + 1, device=x.device).sqrt() |
| x = x * (1 - n / step1sqrt) + torch.randn_like(x) * n / step1sqrt |
| elif self.config.test_time_noise_type == "line": |
| noise = max(n, (self.config.mean_recurrence - current_step) / self.config.mean_recurrence) |
| x = x * (1 - noise) + torch.randn_like(x) * noise |
| elif self.config.test_time_noise_type == "chi": |
| noise = 2 * torch.rand(1, device=x.device, dtype=x.dtype) * n |
| x = x * (1 - noise) + torch.randn_like(x) * noise |
| elif self.config.test_time_noise_type == "fixed": |
| x = x * (1 - n) + torch.randn_like(x) * n |
| else: |
| raise ValueError() |
|
|
| if renorm: |
| x = self.transformer.core_block[-1].norm_4(x) |
| return x |
|
|
| def prepare_inputs_for_generation( |
| self, |
| input_ids: torch.Tensor, |
| past_key_values: Optional[Cache] = None, |
| attention_mask: Optional[torch.Tensor] = None, |
| inputs_embeds: Optional[torch.FloatTensor] = None, |
| cache_position: Optional[torch.Tensor] = None, |
| cache_lookup_strategy: str = "full", |
| **kwargs, |
| ): |
| model_inputs = {} |
| model_inputs["cache_position"] = cache_position |
| current_input_length = input_ids.shape[1] |
|
|
| if past_key_values is not None: |
| if not isinstance(past_key_values, (HuginnDynamicCache, HuginnStaticCache)): |
| assert past_key_values.get_seq_length() == 0 |
| |
| if isinstance(past_key_values, StaticCache): |
| past_key_values = HuginnStaticCache( |
| max_length=getattr(self.generation_config, "max_length", self.config.block_size), |
| max_num_steps=4 + kwargs.get("num_steps", self.config.mean_recurrence) * 4, |
| num_heads=self.config.num_key_value_heads, |
| hidden_dim=self.config.n_embd // self.config.num_attention_heads, |
| dtype=torch.bfloat16, |
| device=input_ids.device, |
| lookup_strategy=cache_lookup_strategy, |
| ) |
| else: |
| past_key_values = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy) |
| model_inputs["past_key_values"] = past_key_values if kwargs["use_cache"] else None |
| input_ids = input_ids[:, cache_position] |
|
|
| model_inputs["input_ids"] = input_ids.clone(memory_format=torch.contiguous_format) |
| if cache_position is None: |
| position_ids = torch.arange(current_input_length)[None, :].to(input_ids.device) |
| model_inputs["position_ids"] = position_ids[:, -current_input_length:].clone( |
| memory_format=torch.contiguous_format |
| ) |
|
|
| |
| for key, value in kwargs.items(): |
| if key not in model_inputs: |
| model_inputs[key] = value |
| return model_inputs |
|
|
| @torch.no_grad() |
| def generate(self, *args, **kwargs): |
| """Dispatcher - use HF generate in all normal cases.""" |
| self.generation_config = args[1] if len(args) > 1 else self.generation_config |
| if any(k in kwargs for k in ("criterion", "exit_threshold")): |
| |
| return self.generate_with_adaptive_compute(*args, **kwargs) |
| elif "continuous_compute" in kwargs: |
| |
| return self.generate_minimal(*args, **kwargs) |
| else: |
| return super().generate(*args, **kwargs) |
|
|
| @torch.no_grad() |
| def _prep_generate_args( |
| self, |
| input_ids: torch.Tensor, |
| generation_config: Optional[GenerationConfig] = None, |
| cache_lookup_strategy: str = "full", |
| model_kwargs: dict = {}, |
| ): |
| |
| if generation_config is None: |
| generation_config: GenerationConfig = self.generation_config |
| if "max_new_tokens" in model_kwargs: |
| max_new_tokens = model_kwargs["max_new_tokens"] |
| if "max_length" in model_kwargs: |
| max_new_tokens = min(max_new_tokens, model_kwargs["max_length"] - input_ids.shape[1]) |
| else: |
| max_length = model_kwargs.get("max_length", generation_config.max_length) |
| max_new_tokens = max_length - input_ids.shape[1] |
|
|
| if "cache_implementation" not in model_kwargs or model_kwargs["cache_implementation"] == "dynamic": |
| model_kwargs["past_key_values"] = HuginnDynamicCache(lookup_strategy=cache_lookup_strategy) |
| else: |
| model_kwargs["past_key_values"] = HuginnStaticCache( |
| max_length=max_length, |
| max_num_steps=4 + model_kwargs.get("num_steps", self.config.mean_recurrence) * 4, |
| num_heads=self.config.num_key_value_heads, |
| hidden_dim=self.config.n_embd // self.config.num_attention_heads, |
| batch_size=input_ids.shape[0], |
| dtype=torch.bfloat16, |
| device=input_ids.device, |
| lookup_strategy=cache_lookup_strategy, |
| ) |
| model_kwargs["use_cache"] = True |
| model_kwargs = self._get_initial_cache_position(input_ids, model_kwargs) |
| return model_kwargs, generation_config, max_new_tokens |
|
|
| @torch.no_grad() |
| def generate_minimal( |
| self, |
| input_ids: torch.Tensor, |
| generation_config: Optional[GenerationConfig] = None, |
| tokenizer=None, |
| streamer=None, |
| continuous_compute=False, |
| init_scale: float = 1.0, |
| cache_lookup_strategy: str = "full", |
| **model_kwargs, |
| ) -> Union[torch.Tensor, dict[str, Any]]: |
| """Minimal single-sequence generation. Template for more complicated generate tasks""" |
| model_kwargs, generation_config, max_new_tokens = self._prep_generate_args( |
| input_ids, generation_config, cache_lookup_strategy |
| ) |
| stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device) |
| unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) |
|
|
| |
| if continuous_compute: |
| embedded_inputs, _ = self.embed_inputs(input_ids) |
| model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale) |
|
|
| |
| batch_size = input_ids.shape[0] |
| for _ in range(max_new_tokens): |
| |
| model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
| outputs = self(**model_inputs, init_scale=init_scale) |
|
|
| |
| next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32, device=input_ids.device) |
| next_token = self._sample_next_token(next_token_logits, generation_config) |
|
|
| |
| input_ids = torch.cat([input_ids, next_token], dim=-1) |
|
|
| if streamer: |
| streamer.put(next_token.cpu()) |
|
|
| |
| model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs) |
| if continuous_compute: |
| model_kwargs["input_states"] = outputs.latent_states[:, -1:, :] |
|
|
| if stop_tokens is not None: |
| for i in range(batch_size): |
| if unfinished_sequences[i] and next_token[i, 0].item() in stop_tokens: |
| unfinished_sequences[i] = 0 |
| if "stopping_criteria" in model_kwargs: |
| unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None) |
| if unfinished_sequences.max() == 0: |
| break |
|
|
| if streamer: |
| streamer.end() |
|
|
| if generation_config.return_dict_in_generate: |
| return GenerateDecoderOnlyOutput( |
| sequences=input_ids, |
| scores=None, |
| logits=None, |
| attentions=None, |
| hidden_states=None, |
| past_key_values=model_kwargs.get("past_key_values"), |
| ) |
| return input_ids |
|
|
| @torch.no_grad() |
| def generate_with_adaptive_compute( |
| self, |
| input_ids: torch.Tensor, |
| generation_config: Optional[GenerationConfig] = None, |
| tokenizer=None, |
| streamer=None, |
| continuous_compute=False, |
| criterion="none", |
| exit_threshold: Union[str, float, int] = "auto", |
| init_scale: float = 1.0, |
| cache_lookup_strategy: str = "full", |
| **model_kwargs, |
| ) -> Union[torch.Tensor, GenerateDecoderOnlyOutput]: |
| """ |
| Generate tokens with adaptive compute. This is NOT the most efficient implementation. |
| For batches, on each token, we iterate until the entire batch finishes. |
| """ |
| model_kwargs, generation_config, max_new_tokens = self._prep_generate_args( |
| input_ids, generation_config, cache_lookup_strategy, model_kwargs |
| ) |
| max_steps = model_kwargs.get("num_steps", self.config.mean_recurrence) |
| stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device) |
| logit_type = dict(copy=True, dtype=torch.float32, device=input_ids.device) |
| batch_size = input_ids.shape[0] |
| compute_steps = [] |
|
|
| |
| if continuous_compute: |
| embedded_inputs, _ = self.embed_inputs(input_ids) |
| model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale) |
|
|
| |
| unfinished_sequences = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) |
|
|
| |
| for _ in range(max_new_tokens): |
| |
| model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
| aux_inputs = { |
| k: model_inputs[k] for k in ["cache_position", "past_key_values", "attention_mask"] if k in model_inputs |
| } |
| embedded_inputs, block_idx = self.embed_inputs(model_inputs["input_ids"], **aux_inputs) |
| current_latents = ( |
| self.initialize_state(embedded_inputs, scale=init_scale) |
| if not continuous_compute |
| else model_kwargs["input_states"] |
| ) |
|
|
| |
| exit_values_per_seq = [[] for _ in range(batch_size)] |
| compute_steps_per_seq = [0] * batch_size |
| exit_reached = torch.zeros(batch_size, dtype=torch.bool, device=input_ids.device) |
|
|
| |
| if criterion == "entropy-diff": |
| entropy = torch.ones(batch_size, device=input_ids.device) * 100.0 |
| exit_threshold = 1e-3 if exit_threshold == "auto" else float(exit_threshold) |
| elif criterion == "latent-diff": |
| exit_threshold = 0.03 if exit_threshold == "auto" else float(exit_threshold) |
| elif "kl" in criterion: |
| V = self.config.padded_vocab_size |
| log_probs = ((1 / V) * torch.ones(batch_size, V, dtype=torch.float, device=input_ids.device)).log() |
| if criterion == "minp-kl": |
| exit_threshold = 1e-6 if exit_threshold == "auto" else float(exit_threshold) |
| else: |
| exit_threshold = 5e-4 if exit_threshold == "auto" else float(exit_threshold) |
| elif criterion == "argmax-stability": |
| stable_for_n_steps = torch.zeros(batch_size, dtype=torch.long, device=input_ids.device) |
| current_argmax = torch.ones(batch_size, dtype=torch.long, device=input_ids.device) * -1 |
| exit_threshold = 5 if exit_threshold == "auto" else int(exit_threshold) |
| elif criterion == "none": |
| exit_threshold = 1.0 if exit_threshold == "auto" else float(exit_threshold) |
| else: |
| raise ValueError("Invalid adaptive compute strategy.") |
|
|
| next_token_logits = None |
|
|
| |
| for compute_step in range(max_steps): |
| prev_latents = current_latents.clone() |
| current_latents, block_idx, _ = self.iterate_one_step( |
| embedded_inputs, |
| current_latents, |
| block_idx=block_idx, |
| **aux_inputs, |
| current_step=compute_step, |
| ) |
|
|
| if _ > 0: |
| |
| if criterion == "entropy-diff": |
| prev_entropy = entropy |
| outputs = self.predict_from_latents(current_latents, **aux_inputs) |
| logits: torch.Tensor = outputs.logits |
| probs = F.softmax(logits[:, -1, :], dim=-1) |
| entropy = -torch.sum(probs * torch.log(probs + 1e-10), dim=-1) |
| exit_values = (entropy - prev_entropy).abs() |
| elif criterion == "latent-diff": |
| norm_diff = (prev_latents - current_latents).norm(dim=-1) / current_latents.norm(dim=-1) |
| exit_values = norm_diff.mean(dim=-1) |
| elif "kl" in criterion: |
| outputs = self.predict_from_latents(current_latents, **aux_inputs) |
| logits: torch.Tensor = outputs.logits |
| prev_log_probs = log_probs |
| if criterion == "minp-kl": |
| probs = F.softmax(logits[:, -1, :].float(), dim=-1) |
| max_probs = probs.max(dim=-1, keepdim=True)[0] |
| probs_mask = probs < (0.1 * max_probs) |
| masked_probs = probs.clone() |
| masked_probs[probs_mask] = 1 / V |
| probs = masked_probs / masked_probs.sum(dim=-1, keepdim=True) |
| log_probs = probs.log() |
| else: |
| log_probs = F.log_softmax(logits[:, -1, :].float(), dim=-1) |
| exit_values = F.kl_div(log_probs, prev_log_probs, reduction="none", log_target=True).sum(dim=-1) |
| elif criterion == "argmax-stability": |
| prev_argmax = current_argmax |
| outputs = self.predict_from_latents(current_latents, **aux_inputs) |
| logits: torch.Tensor = outputs.logits |
| current_argmax = logits[:, -1, :].argmax(dim=-1) |
| stable_for_n_steps = torch.where( |
| current_argmax == prev_argmax, stable_for_n_steps + 1, torch.zeros_like(stable_for_n_steps) |
| ) |
| exit_values = stable_for_n_steps |
| elif criterion == "none": |
| exit_values = torch.ones(batch_size, device=input_ids.device) * 2.0 * exit_threshold |
|
|
| |
| for i in range(batch_size): |
| if not exit_reached[i] and unfinished_sequences[i].bool(): |
| exit_values_per_seq[i].append(exit_values[i].item()) |
|
|
| |
| new_exits = ( |
| exit_values < exit_threshold |
| if criterion != "argmax-stability" |
| else exit_values >= exit_threshold |
| ) |
| new_exits = new_exits & ~exit_reached & unfinished_sequences.bool() |
|
|
| if new_exits.any(): |
| exit_reached = exit_reached | new_exits |
| if criterion == "latent-diff": |
| |
| |
| outputs = self.predict_from_latents(current_latents, **aux_inputs) |
| logits: torch.Tensor = outputs.logits |
| if next_token_logits is None: |
| next_token_logits = logits[:, -1, :].to(**logit_type) |
| else: |
| for i in range(batch_size): |
| if new_exits[i]: |
| next_token_logits[i] = logits[i, -1, :].to(**logit_type) |
| for i in range(batch_size): |
| if new_exits[i]: |
| compute_steps_per_seq[i] = compute_step + 1 |
|
|
| |
| if (exit_reached | ~unfinished_sequences.bool()).all(): |
| break |
| |
| else: |
| outputs = self.predict_from_latents(current_latents, **aux_inputs) |
|
|
| |
| if next_token_logits is None: |
| next_token_logits = outputs.logits[:, -1, :].to(**logit_type) |
| else: |
| for i in range(batch_size): |
| if not exit_reached[i] and unfinished_sequences[i].bool(): |
| next_token_logits[i] = outputs.logits[i, -1, :].to(**logit_type) |
| compute_steps_per_seq[i] = max_steps |
|
|
| |
| if continuous_compute: |
| model_kwargs["input_states"] = current_latents[:, -1:, :] |
|
|
| |
| compute_steps.append([compute_steps_per_seq, exit_values_per_seq]) |
|
|
| |
| next_token = self._sample_next_token(next_token_logits, generation_config) |
|
|
| |
| input_ids = torch.cat([input_ids, next_token], dim=-1) |
|
|
| if streamer: |
| streamer.put(next_token.cpu()) |
|
|
| |
| model_kwargs = self._update_model_kwargs_for_generation(outputs, model_kwargs) |
|
|
| |
| for i in range(batch_size): |
| if ( |
| unfinished_sequences[i].bool() |
| and stop_tokens is not None |
| and next_token[i, 0].item() in stop_tokens |
| ): |
| unfinished_sequences[i] = 0 |
|
|
| |
| if "stopping_criteria" in model_kwargs: |
| unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None) |
|
|
| |
| if unfinished_sequences.max() == 0: |
| break |
|
|
| if streamer: |
| streamer.end() |
|
|
| if generation_config.return_dict_in_generate: |
| return GenerateDecoderOnlyOutput( |
| sequences=input_ids, |
| scores=compute_steps, |
| logits=None, |
| attentions=None, |
| hidden_states=None, |
| past_key_values=model_kwargs.get("past_key_values"), |
| ) |
| return input_ids |
|
|
| def _get_stops(self, generation_config, tokenizer, model_kwargs): |
| stop_tokens = {65504, 65505, 65508} |
| if generation_config.eos_token_id is not None: |
| stop_tokens.add(generation_config.eos_token_id) |
| if "stopping_criteria" in model_kwargs and tokenizer is None: |
| tokenizer = model_kwargs["stopping_criteria"][0].tokenizer |
| if hasattr(generation_config, "stop_strings") and tokenizer and generation_config.stop_strings: |
| for s in generation_config.stop_strings: |
| token_id = tokenizer(s, add_special_tokens=False)["input_ids"][0] |
| stop_tokens.add(token_id) |
| return torch.tensor(list(stop_tokens)) |
|
|
| def _sample_next_token(self, next_token_logits, generation_config): |
| """Helper function to sample the next token.""" |
| if generation_config.do_sample: |
| if generation_config.temperature: |
| next_token_logits = next_token_logits.float() / generation_config.temperature |
|
|
| probs = F.softmax(next_token_logits, dim=-1) |
|
|
| |
| if generation_config.top_k: |
| top_k_values, _ = torch.topk(probs, generation_config.top_k, dim=-1) |
| min_values = top_k_values[:, -1].unsqueeze(-1).expand_as(probs) |
| probs = torch.where(probs < min_values, torch.zeros_like(probs), probs) |
|
|
| |
| if generation_config.top_p: |
| sorted_probs, sorted_indices = torch.sort(probs, descending=True, dim=-1) |
| cumulative_probs = torch.cumsum(sorted_probs, dim=-1) |
|
|
| |
| remove_indices = cumulative_probs > generation_config.top_p |
| remove_indices[:, 0] = False |
|
|
| |
| mask = torch.zeros_like(probs, dtype=torch.bool) |
| for i in range(probs.shape[0]): |
| mask[i, sorted_indices[i, remove_indices[i]]] = True |
|
|
| probs = torch.where(mask, torch.zeros_like(probs), probs) |
|
|
| |
| if generation_config.min_p: |
| max_probs = probs.max(dim=-1, keepdim=True)[0] |
| min_p_threshold = generation_config.min_p * max_probs |
| probs = torch.where(probs < min_p_threshold, torch.zeros_like(probs), probs) |
|
|
| |
| probs = probs / probs.sum(dim=-1, keepdim=True).clamp(min=1e-10) |
|
|
| |
| return torch.multinomial(probs, num_samples=1) |
| else: |
| return torch.argmax(next_token_logits, dim=-1, keepdim=True) |
|
|
| @torch.no_grad() |
| def generate_speculative( |
| self, |
| input_ids: torch.Tensor, |
| generation_config: Optional[GenerationConfig] = None, |
| tokenizer=None, |
| streamer=None, |
| continuous_compute=False, |
| init_scale: float = 1.0, |
| cache_lookup_strategy: str = "full", |
| draft_steps=32, |
| lookahead_for_draft=8, |
| verification_threshold=1, |
| num_steps: int = 32, |
| **model_kwargs, |
| ) -> Union[torch.Tensor, dict[str, Any]]: |
| """Batched speculative decoding with per-sequence acceptance.""" |
| assert lookahead_for_draft > 0 |
| pad_id = 65509 |
| model_kwargs, generation_config, max_new_tokens = self._prep_generate_args( |
| input_ids, generation_config, cache_lookup_strategy, model_kwargs |
| ) |
| stop_tokens = self._get_stops(generation_config, tokenizer, model_kwargs).to(input_ids.device) |
| unfinished_sequences = torch.ones(input_ids.shape[0], dtype=torch.long, device=input_ids.device) |
|
|
| |
| if continuous_compute: |
| embedded_inputs, _ = self.embed_inputs(input_ids) |
| model_kwargs["input_states"] = self.initialize_state(embedded_inputs, scale=init_scale) |
|
|
| tokens_generated = 0 |
| |
| if model_kwargs["past_key_values"].get_seq_length() == 0: |
| model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs) |
| outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale) |
| next_token = self._sample_next_token( |
| outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32), generation_config |
| ) |
| input_ids = torch.cat([input_ids, next_token], dim=-1) |
| tokens_generated += 1 |
| if streamer: |
| streamer.put(next_token.cpu()) |
| model_kwargs["cache_position"] = torch.as_tensor( |
| [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device |
| ) |
| if continuous_compute: |
| model_kwargs["input_states"] = outputs.latent_states[:, -1:, :] |
|
|
| |
| batch_size, prefix_seq_len = input_ids.shape[0], input_ids.shape[1] |
| accepted_tokens = [] |
|
|
| while tokens_generated < max_new_tokens: |
| |
| drafted_inputs = input_ids.clone() |
| current_len = input_ids.shape[1] |
|
|
| for _ in range(lookahead_for_draft): |
| model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs) |
| outputs = self(**model_inputs, num_steps=draft_steps, init_scale=init_scale) |
| next_token_logits = outputs.logits[:, -1, :].to(copy=True, dtype=torch.float32) |
| next_token = self._sample_next_token(next_token_logits, generation_config) |
| drafted_inputs = torch.cat([drafted_inputs, next_token], dim=-1) |
| model_kwargs["cache_position"] = model_kwargs["cache_position"][-1:] + 1 |
| if continuous_compute: |
| model_kwargs["input_states"] = outputs.latent_states[:, -1:, :] |
|
|
| model_kwargs["past_key_values"].clear_last_k_entries(lookahead_for_draft) |
|
|
| |
| model_kwargs["cache_position"] = torch.arange( |
| current_len - 1, current_len + lookahead_for_draft - 1, device=input_ids.device |
| ) |
| model_inputs = self.prepare_inputs_for_generation(drafted_inputs, **model_kwargs) |
| outputs = self(**model_inputs, num_steps=num_steps, init_scale=init_scale) |
| verified_next_token_preds = outputs.logits.argmax(dim=-1) |
|
|
| if verification_threshold >= 1: |
| mismatched_tokens = ( |
| verified_next_token_preds[:, -lookahead_for_draft:] != drafted_inputs[:, current_len:] |
| ) |
| not_all_matched, first_mismatch = torch.max(mismatched_tokens, dim=1) |
| else: |
| verified_logits = outputs.logits[:, -lookahead_for_draft:, :] |
| verified_probs = F.softmax(verified_logits, dim=-1) |
| drafted_token_probs = torch.gather( |
| verified_probs, -1, drafted_inputs[:, current_len:].unsqueeze(-1) |
| ).squeeze(-1) |
| max_probs = verified_probs.max(dim=-1)[0] |
| verification_passed = drafted_token_probs >= verification_threshold * max_probs |
| not_all_matched, first_mismatch = torch.max(~verification_passed, dim=1) |
|
|
| |
| acceptance_lengths = torch.where(not_all_matched, first_mismatch, lookahead_for_draft) |
|
|
| |
| next_tokens_batch = [] |
| for i in range(batch_size): |
| seq_acceptance = acceptance_lengths[i].item() |
| if not_all_matched[i] and seq_acceptance < lookahead_for_draft: |
| |
| accepted_part = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance] |
| final_token_logits = outputs.logits[i : i + 1, seq_acceptance, :].to(copy=True, dtype=torch.float32) |
| final_token = self._sample_next_token(final_token_logits, generation_config) |
| seq_tokens = torch.cat([accepted_part, final_token], dim=-1) if seq_acceptance > 0 else final_token |
| else: |
| |
| seq_tokens = drafted_inputs[i : i + 1, current_len : current_len + seq_acceptance] |
| next_tokens_batch.append(seq_tokens) |
|
|
| |
| if not_all_matched.any(): |
| min_first_mismatch = first_mismatch.min().item() |
| model_inputs["past_key_values"].clear_last_k_entries(lookahead_for_draft - min_first_mismatch - 1) |
|
|
| |
| batch_accepted_counts = [tokens.shape[1] for tokens in next_tokens_batch] |
| max_len = max(batch_accepted_counts) |
| padded_tokens = [ |
| torch.cat( |
| [ |
| tokens, |
| pad_id * torch.ones((1, max_len - tokens.shape[1]), dtype=tokens.dtype, device=tokens.device), |
| ], |
| dim=-1, |
| ) |
| if tokens.shape[1] < max_len |
| else tokens |
| for tokens in next_tokens_batch |
| ] |
| next_tokens = torch.cat(padded_tokens, dim=0) |
| input_ids = torch.cat([input_ids, next_tokens], dim=-1) |
|
|
| accepted_tokens.append(batch_accepted_counts) |
| tokens_generated += max(batch_accepted_counts) |
|
|
| if streamer: |
| streamer.put(next_tokens_batch[0].cpu()) |
|
|
| model_kwargs["cache_position"] = torch.as_tensor( |
| [model_inputs["past_key_values"].get_seq_length()], device=input_ids.device |
| ) |
| if continuous_compute: |
| model_kwargs["input_states"] = outputs.latent_states[:, -1:, :] |
|
|
| |
| if stop_tokens is not None: |
| for i in range(batch_size): |
| if unfinished_sequences[i] and torch.isin(next_tokens_batch[i], stop_tokens).any(): |
| unfinished_sequences[i] = 0 |
| if "stopping_criteria" in model_kwargs: |
| unfinished_sequences = unfinished_sequences & ~model_kwargs["stopping_criteria"](input_ids, None) |
| if unfinished_sequences.max() == 0: |
| break |
|
|
| if streamer: |
| streamer.end() |
|
|
| |
| if stop_tokens is not None: |
| for i in range(batch_size): |
| stop_positions = torch.isin(input_ids[i, prefix_seq_len:], stop_tokens).nonzero() |
| if len(stop_positions) > 0: |
| input_ids[i, prefix_seq_len + stop_positions[0].item() + 1 :] = pad_id |
| |
| non_pad_mask = input_ids != pad_id |
| last_real_token = non_pad_mask.any(dim=0).nonzero() |
| if len(last_real_token) > 0: |
| input_ids = input_ids[:, : last_real_token[-1].item() + 1] |
|
|
| if generation_config.return_dict_in_generate: |
| return GenerateDecoderOnlyOutput( |
| sequences=input_ids, |
| scores=accepted_tokens, |
| logits=None, |
| attentions=None, |
| hidden_states=None, |
| past_key_values=model_kwargs.get("past_key_values"), |
| ) |
| return input_ids |
|
|
| def get_stats(self, logits, x, latent_states, xk, input_embeds, num_steps_no_grad, num_steps_with_grad): |
| probs = torch.softmax(logits.float(), dim=-1) |
| prob_entropy = torch.where(probs > 0, -probs * probs.log(), 0).sum(dim=-1) |
| residual_diff = (x - latent_states).norm(dim=-1) |
| rel_residual = residual_diff / latent_states.norm(dim=-1) |
| stats = { |
| "entropy": prob_entropy, |
| "residual_diff": residual_diff, |
| "rel_residual": rel_residual, |
| "num_steps_no_grad": num_steps_no_grad, |
| "num_steps_with_grad": num_steps_with_grad, |
| } |
| return stats |
|
|
|
|
| |
|
|
| from transformers import AutoConfig, AutoModel, AutoModelForCausalLM |
|
|
| |
| RavenConfig.register_for_auto_class() |
|
|
| RavenForCausalLM.register_for_auto_class("AutoModel") |
| RavenForCausalLM.register_for_auto_class("AutoModelForCausalLM") |
|
|
| |
| AutoConfig.register("huginn_raven", RavenConfig) |
| AutoModel.register(RavenConfig, RavenForCausalLM) |
| AutoModelForCausalLM.register(RavenConfig, RavenForCausalLM) |
|
|