import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import PreTrainedModel, GenerationMixin
from transformers.activations import ACT2FN
from transformers.modeling_outputs import CausalLMOutputWithPast, BaseModelOutputWithPast
from typing import Optional, Tuple, List, Union
import inspect
from dataclasses import dataclass

try:
    # Used when dynamically loaded by HF Hub (`trust_remote_code=True`)
    from .configuration_model import HybridModelConfig
except ImportError:
    # Used when running local scripts directly
    from configuration_model import HybridModelConfig


def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
    """Precompute complex RoPE rotations for positions ``0 .. end - 1``.

    Returns a complex64 tensor of shape ``(end, dim // 2)`` whose entry ``[t, k]`` is
    ``exp(i * t * theta ** (-2k / dim))``.
    """
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
    t = torch.arange(end, device=freqs.device)
    freqs = torch.outer(t, freqs).float()
    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
    return freqs_cis


def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
    """Reshape ``freqs_cis`` (seq, dim) so it broadcasts against ``x`` over dims 1 and -1."""
    ndim = x.ndim
    assert ndim > 1
    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
    return freqs_cis.view(*shape)


def apply_rotary_emb(
    xq: torch.Tensor,
    xk: torch.Tensor,
    q_freqs_cis: torch.Tensor,
    k_freqs_cis: torch.Tensor,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Rotate queries and keys with separate (position-aligned) RoPE tables.

    Adjacent feature pairs of the last dim are treated as complex numbers and multiplied
    by the rotations; the computation runs in float32 and is cast back to the inputs' dtypes.
    Queries and keys take separate tables because, with a KV cache, they cover different
    position ranges.
    """
    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2))
    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2))
    q_freqs = reshape_for_broadcast(q_freqs_cis, xq_)
    k_freqs = reshape_for_broadcast(k_freqs_cis, xk_)
    xq_out = torch.view_as_real(xq_ * q_freqs).flatten(xq.ndim - 1)
    xk_out = torch.view_as_real(xk_ * k_freqs).flatten(xk.ndim - 1)
    return xq_out.type_as(xq), xk_out.type_as(xk)


class RMSNorm(nn.Module):
    """Root-mean-square normalization over the last dim, computed in float32."""

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, hidden_states):
        input_dtype = hidden_states.dtype
        hidden_states = hidden_states.to(torch.float32)
        variance = hidden_states.pow(2).mean(-1, keepdim=True)
        hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
        # Cast back before applying the learned scale.
        return self.weight * hidden_states.to(input_dtype)


# ================================
# MHC (Multi-Head Connections) Implementation
# ================================


def sinkhorn_knopp(
    logits: torch.Tensor,
    *,
    tmax: int = 20,
    eps: float = 1e-8,
    clamp_min: float = 0.0,
) -> torch.Tensor:
    """Map logits over the last two dims to (approximately) doubly-stochastic matrices.

    Runs ``tmax`` rounds of alternating row/column normalization in the log domain for
    numerical stability, exponentiates, optionally clamps from below, then does one final
    row and column renormalization. Always returns float32.
    """
    log_m = logits.float()
    # Subtracting the per-matrix max does not change the result but avoids overflow.
    log_m = log_m - log_m.amax(dim=(-2, -1), keepdim=True)
    for _ in range(tmax):
        log_m = log_m - torch.logsumexp(log_m, dim=-1, keepdim=True)
        log_m = log_m - torch.logsumexp(log_m, dim=-2, keepdim=True)
    m = torch.exp(log_m)
    if clamp_min is not None and clamp_min > 0:
        m = m.clamp_min(clamp_min)
    m = m / (m.sum(dim=-1, keepdim=True) + eps)
    m = m / (m.sum(dim=-2, keepdim=True) + eps)
    return m


@dataclass(frozen=True)
class MhcMappings:
    """Per-token MHC coefficients produced by :class:`MhcProjector`."""

    h_pre: torch.Tensor  # (B, T, n): weights for reading the layer input from the streams
    h_post: torch.Tensor  # (B, T, n): weights for writing the layer output into the streams
    h_res: torch.Tensor  # (B, T, n, n): Sinkhorn-normalized stream-mixing matrix


class MhcProjector(nn.Module):
    """Compute MHC read/write/mix coefficients from the flattened residual streams.

    For ``x_stream`` of shape ``(B, T, n, C)`` each mapping is
    ``alpha * (RMSNorm(x_flat) @ phi) + b`` followed by its nonlinearity:
    ``sigmoid`` for ``h_pre``, ``2 * sigmoid`` for ``h_post`` and Sinkhorn-Knopp for ``h_res``.
    """

    def __init__(
        self,
        *,
        n_streams: int,
        hidden_dim: int,
        tmax: int = 20,
        alpha_init: float = 0.01,
        rmsnorm_eps: float = 1e-6,
    ):
        super().__init__()
        self.n = int(n_streams)
        self.c = int(hidden_dim)
        self.tmax = int(tmax)
        # Remembered so reset_parameters() can restore the configured gate strength.
        self.alpha_init = float(alpha_init)
        flat_dim = self.n * self.c
        self.rmsnorm = RMSNorm(flat_dim, eps=rmsnorm_eps)
        self.phi_pre = nn.Parameter(torch.empty(flat_dim, self.n))
        self.phi_post = nn.Parameter(torch.empty(flat_dim, self.n))
        self.phi_res = nn.Parameter(torch.empty(flat_dim, self.n * self.n))
        self.b_pre = nn.Parameter(torch.zeros(self.n))
        self.b_post = nn.Parameter(torch.zeros(self.n))
        self.b_res = nn.Parameter(torch.zeros(self.n, self.n))
        self.alpha_pre = nn.Parameter(torch.tensor(self.alpha_init))
        self.alpha_post = nn.Parameter(torch.tensor(self.alpha_init))
        self.alpha_res = nn.Parameter(torch.tensor(self.alpha_init))
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Start from the GPT-2-equivalent mapping while keeping the dynamic path trainable.

        ``phi`` is zeroed, so the dynamic term contributes nothing at init. ``alpha`` keeps
        its configured value: with ``phi == 0`` *and* ``alpha == 0`` the gradient of ``phi``
        (scaled by ``alpha``) and of ``alpha`` (scaled by ``x @ phi``) are both identically
        zero, which would freeze the dynamic mappings at zero forever.
        """
        self.init_gpt2_equivalence(alpha=self.alpha_init)

    @torch.no_grad()
    def init_gpt2_equivalence(self, *, offdiag_bias: float = -20.0, alpha: float = 0.0) -> None:
        """Set parameters so the block initially behaves like a plain residual connection.

        Biases give ``h_pre = 1/n`` per stream (for ``n > 1``), ``h_post = 1`` and an
        ``h_res`` that is numerically the identity (off-diagonal logits ``offdiag_bias``).
        ``phi`` is zeroed and every ``alpha`` is set to ``alpha``. Passing ``alpha == 0``
        leaves ``phi`` and ``alpha`` with zero gradients, i.e. the dynamic path never trains.
        """
        self.phi_pre.zero_()
        self.phi_post.zero_()
        self.phi_res.zero_()
        self.alpha_pre.fill_(alpha)
        self.alpha_post.fill_(alpha)
        self.alpha_res.fill_(alpha)
        p = 1.0 / float(self.n)
        logit_p = math.log(p / (1.0 - p)) if p not in (0.0, 1.0) else 0.0
        self.b_pre.fill_(logit_p)
        self.b_post.zero_()
        self.b_res.fill_(offdiag_bias)
        self.b_res.diagonal().fill_(0.0)

    def forward(self, x_stream: torch.Tensor) -> MhcMappings:
        b, t, n, c = x_stream.shape
        x_flat = x_stream.reshape(b * t, n * c)
        x_flat = self.rmsnorm(x_flat)
        h_pre_tilde = self.alpha_pre * (x_flat @ self.phi_pre) + self.b_pre
        h_post_tilde = self.alpha_post * (x_flat @ self.phi_post) + self.b_post
        h_res_dyn = x_flat @ self.phi_res
        h_res_tilde = self.alpha_res * h_res_dyn.reshape(b * t, n, n) + self.b_res
        h_pre = torch.sigmoid(h_pre_tilde).reshape(b, t, n)
        h_post = (2.0 * torch.sigmoid(h_post_tilde)).reshape(b, t, n)
        h_res = sinkhorn_knopp(h_res_tilde.reshape(b, t, n, n), tmax=self.tmax)
        return MhcMappings(h_pre=h_pre, h_post=h_post, h_res=h_res)


def stream_weighted_sum(x_stream: torch.Tensor, weights: torch.Tensor) -> torch.Tensor:
    """Collapse streams ``(B, T, n, C)`` to ``(B, T, C)`` with per-token weights ``(B, T, n)``."""
    if weights.dtype != x_stream.dtype:
        weights = weights.to(dtype=x_stream.dtype)
    return torch.einsum("btn,btnc->btc", weights, x_stream)


def stream_mix(x_stream: torch.Tensor, h_res: torch.Tensor) -> torch.Tensor:
    """Mix streams with a per-token ``(B, T, n, n)`` matrix: ``out_i = sum_j h_res_ij x_j``."""
    if h_res.dtype != x_stream.dtype:
        h_res = h_res.to(dtype=x_stream.dtype)
    return torch.einsum("btij,btjc->btic", h_res, x_stream)


def stream_write(y: torch.Tensor, h_post: torch.Tensor) -> torch.Tensor:
    """Broadcast a layer output ``(B, T, C)`` into ``n`` streams scaled by ``h_post`` ``(B, T, n)``."""
    if h_post.dtype != y.dtype:
        h_post = h_post.to(dtype=y.dtype)
    return h_post.unsqueeze(-1) * y.unsqueeze(-2)


def mhc_update(x_stream: torch.Tensor, *, h_post: torch.Tensor, h_res: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """MHC residual update: mix the existing streams, then add the written layer output."""
    return stream_mix(x_stream, h_res) + stream_write(y, h_post)


# ================================


class HybridMLAAttention(nn.Module):
    """Multi-head latent attention with a decoupled, shared RoPE key.

    Keys and values are reconstructed from a compressed latent ``C_KV`` (``kv_lora_rank``),
    and queries from a compressed latent ``C_Q`` (``q_lora_rank``). The KV cache stores the
    latent ``C_KV`` and the *un-rotated* RoPE key, so rotary embeddings are re-applied to
    the full key range on every step. Even-indexed layers use sliding-window attention.
    """

    def __init__(self, config: HybridModelConfig, layer_idx: int):
        super().__init__()
        self.config = config
        self.layer_idx = layer_idx
        self.d_model = config.hidden_size
        self.num_head = config.num_attention_heads
        self.d_head = self.d_model // self.num_head
        self.d_embed = config.hidden_size
        self.d_c = config.kv_lora_rank
        self.d_c1 = config.q_lora_rank
        self.d_rotate = config.qk_rope_head_dim
        self.dropout_rate = config.attention_dropout
        self.sliding_window = config.sliding_window if layer_idx % 2 == 0 else None

        self.DKV_proj = nn.Linear(self.d_embed, self.d_c, bias=False)
        self.DQ_proj = nn.Linear(self.d_embed, self.d_c1, bias=False)
        self.UQ_proj = nn.Linear(self.d_c1, self.d_model, bias=False)
        self.UK_proj = nn.Linear(self.d_c, self.d_model, bias=False)
        self.UV_proj = nn.Linear(self.d_c, self.d_model, bias=False)
        self.RQ_proj = nn.Linear(self.d_c1, self.num_head * self.d_rotate, bias=False)
        # A single RoPE key head shared across all query heads.
        self.RK_proj = nn.Linear(self.d_embed, self.d_rotate, bias=False)
        self.o_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.dropout = nn.Dropout(p=self.dropout_rate)
        # Query/key dims are d_head + d_rotate after concatenating the RoPE parts.
        self.scaler = float(1.0 / math.sqrt(self.d_head + self.d_rotate))

    def forward(self, hidden_states, attention_mask=None, past_key_value=None, freqs_cis=None, use_cache=False):
        """Attend over ``hidden_states`` (B, T, hidden).

        ``attention_mask`` is additive and must broadcast to ``(B, heads, T, kv_len)``.
        ``past_key_value`` is ``(C_KV, K_rotate)`` from a previous call, or ``None``.
        Returns ``(output, None, present_key_value)``; the cache entry is ``None`` unless
        ``use_cache``.
        """
        batch_size, seq_len, _ = hidden_states.size()
        start_pos = past_key_value[0].size(1) if past_key_value is not None else 0

        C_Q = self.DQ_proj(hidden_states)
        Q_state = self.UQ_proj(C_Q)
        Q_rotate = self.RQ_proj(C_Q)

        C_KV = self.DKV_proj(hidden_states)
        K_rotate = self.RK_proj(hidden_states)

        if past_key_value is not None:
            C_KV_cache, K_rotate_cache = past_key_value
            C_KV = torch.cat([C_KV_cache, C_KV], dim=1)
            K_rotate = torch.cat([K_rotate_cache, K_rotate], dim=1)

        # Cache latents and un-rotated RoPE keys (before up-projection and rotation).
        present_key_value = (C_KV, K_rotate) if use_cache else None
        actual_kv_len = C_KV.size(1)

        K_state = self.UK_proj(C_KV)
        V_state = self.UV_proj(C_KV)

        Q_state = Q_state.view(batch_size, seq_len, self.num_head, self.d_head)
        K_state = K_state.view(batch_size, actual_kv_len, self.num_head, self.d_head)
        V_state = V_state.view(batch_size, actual_kv_len, self.num_head, self.d_head)

        Q_rotate = Q_rotate.view(batch_size, seq_len, self.num_head, self.d_rotate)
        K_rotate = K_rotate.unsqueeze(2).expand(-1, -1, self.num_head, -1)

        if freqs_cis is not None:
            # Queries sit at absolute positions start_pos..start_pos+seq_len-1; keys at 0..kv_len-1.
            q_freqs = freqs_cis[start_pos : start_pos + seq_len]
            k_freqs = freqs_cis[:actual_kv_len]
            Q_rotate, K_rotate = apply_rotary_emb(Q_rotate, K_rotate, q_freqs, k_freqs)

        Q_state = torch.cat([Q_state, Q_rotate], dim=-1)
        K_state = torch.cat([K_state, K_rotate], dim=-1)

        Q_state = Q_state * self.scaler
        Q_state = Q_state.transpose(1, 2)
        K_state = K_state.transpose(1, 2)
        V_state = V_state.transpose(1, 2)

        att_matrix = torch.matmul(Q_state, K_state.transpose(-1, -2))

        if attention_mask is not None:
            att_matrix = att_matrix + attention_mask

        if self.sliding_window is not None and actual_kv_len > 1:
            # Query i (absolute position p = kv_len - seq_len + i) may see keys j with
            # p - sliding_window < j <= p; everything else is masked.
            window_mask = torch.ones(seq_len, actual_kv_len, dtype=torch.bool, device=hidden_states.device)
            window_mask = torch.tril(window_mask, diagonal=actual_kv_len - seq_len)
            window_mask = torch.triu(window_mask, diagonal=actual_kv_len - seq_len + 1 - self.sliding_window)
            window_mask = ~window_mask
            att_matrix.masked_fill_(window_mask[None, None, :, :], torch.finfo(att_matrix.dtype).min)

        att_score = F.softmax(att_matrix, dim=-1, dtype=torch.float32).to(Q_state.dtype)
        att_score = self.dropout(att_score)
        att_output = torch.matmul(att_score, V_state)
        att_output = att_output.transpose(1, 2).contiguous().view(batch_size, seq_len, self.num_head * self.d_head)
        att_output = self.o_proj(att_output)
        return att_output, None, present_key_value


class HybridMLP(nn.Module):
    """Gated feed-forward block: ``down(act(gate(x)) * up(x))``."""

    def __init__(self, config):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size
        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = ACT2FN[config.hidden_act]

    def forward(self, x):
        return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x))


class HybridDecoderLayer(nn.Module):
    """Pre-norm attention + MLP layer whose residual path is an MHC multi-stream state."""

    def __init__(self, config: HybridModelConfig, layer_idx: int):
        super().__init__()
        self.hidden_size = config.hidden_size
        self.self_attn = HybridMLAAttention(config=config, layer_idx=layer_idx)
        self.mlp = HybridMLP(config)
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        # MHC modules
        self.mhc_attn = MhcProjector(
            n_streams=config.mhc_num_streams,
            hidden_dim=config.hidden_size,
            tmax=config.mhc_sinkhorn_iters,
            alpha_init=config.mhc_alpha_init,
            rmsnorm_eps=config.mhc_rmsnorm_eps,
        )
        self.mhc_mlp = MhcProjector(
            n_streams=config.mhc_num_streams,
            hidden_dim=config.hidden_size,
            tmax=config.mhc_sinkhorn_iters,
            alpha_init=config.mhc_alpha_init,
            rmsnorm_eps=config.mhc_rmsnorm_eps,
        )

    def forward(self, hidden_states, attention_mask=None, past_key_value=None, freqs_cis=None, use_cache=False):
        """Run one layer on ``hidden_states`` = x_stream of shape ``(B, T, n_streams, C)``.

        Returns ``(x_stream, present_key_value)``.
        """
        x_stream = hidden_states

        # Attention step: read a single input from the streams, attend, write back + mix.
        maps_attn = self.mhc_attn(x_stream)
        x_in = stream_weighted_sum(x_stream, maps_attn.h_pre)
        x_in = self.input_layernorm(x_in)
        attn_out, _, present_key_value = self.self_attn(
            hidden_states=x_in,
            attention_mask=attention_mask,
            past_key_value=past_key_value,
            freqs_cis=freqs_cis,
            use_cache=use_cache,
        )
        x_stream = mhc_update(x_stream, h_post=maps_attn.h_post, h_res=maps_attn.h_res, y=attn_out)

        # MLP step
        maps_mlp = self.mhc_mlp(x_stream)
        x_in2 = stream_weighted_sum(x_stream, maps_mlp.h_pre)
        x_in2 = self.post_attention_layernorm(x_in2)
        mlp_out = self.mlp(x_in2)
        x_stream = mhc_update(x_stream, h_post=maps_mlp.h_post, h_res=maps_mlp.h_res, y=mlp_out)

        return x_stream, present_key_value


class HybridPreTrainedModel(PreTrainedModel):
    """Shared config binding and weight initialization for the Hybrid models."""

    config_class = HybridModelConfig
    base_model_prefix = "model"
    supports_gradient_checkpointing = True
    _supports_cache_class = False  # use legacy tuple KV cache, not DynamicCache

    def _init_weights(self, module):
        # Only Linear and Embedding weights are (re)initialized here; RMSNorm weights,
        # MHC projector parameters and the readout logits keep their constructor init.
        std = self.config.initializer_range
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=std)
            if module.padding_idx is not None:
                module.weight.data[module.padding_idx].zero_()


class HybridModel(HybridPreTrainedModel):
    """Decoder stack operating on ``mhc_num_streams`` parallel residual streams."""

    def __init__(self, config: HybridModelConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx)
        self.layers = nn.ModuleList(
            [HybridDecoderLayer(config, layer_idx) for layer_idx in range(config.num_hidden_layers)]
        )
        self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        freqs_cis = precompute_freqs_cis(config.qk_rope_head_dim, config.max_position_embeddings, config.rope_theta)
        # Stored as real (cos, sin) pairs: Module.to(<float dtype>) also casts complex
        # buffers and would drop their imaginary part. forward() rebuilds the complex table.
        self.register_buffer("freqs_cis", torch.view_as_real(freqs_cis), persistent=False)

        # MHC Readout
        self.mhc_readout_logits = nn.Parameter(torch.zeros(config.mhc_num_streams))
        self._init_readout()

        # Toggled by PreTrainedModel.gradient_checkpointing_enable()/disable().
        self.gradient_checkpointing = False

        self.post_init()

    def _init_readout(self) -> None:
        """Initialize readout logits: uniform ("mean") or strongly favoring stream 0."""
        with torch.no_grad():
            if self.config.mhc_readout_init == "mean":
                self.mhc_readout_logits.zero_()
            else:
                self.mhc_readout_logits.fill_(-5.0)
                self.mhc_readout_logits[0] = 5.0

    def _stream_init(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Expand embeddings ``(B, T, C)`` into streams ``(B, T, n, C)``.

        With ``mhc_stream_init == "copy"`` every stream gets a copy; otherwise only
        stream 0 is populated and the rest are zero.
        """
        b, t, c = hidden_states.shape
        n = self.config.mhc_num_streams
        if self.config.mhc_stream_init == "copy":
            return hidden_states.unsqueeze(-2).expand(b, t, n, c).contiguous()
        x_stream = hidden_states.new_zeros((b, t, n, c))
        x_stream[:, :, 0, :] = hidden_states
        return x_stream

    def _readout(self, x_stream: torch.Tensor) -> torch.Tensor:
        """Collapse streams to ``(B, T, C)`` with softmax(readout logits) weights."""
        w = torch.softmax(self.mhc_readout_logits, dim=0).to(dtype=x_stream.dtype)
        return torch.einsum("n,btnc->btc", w, x_stream)

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        use_cache=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Run the decoder on ``input_ids`` (B, T).

        ``attention_mask`` (B, past + T) marks real tokens with 1. ``past_key_values`` is
        a legacy tuple of per-layer ``(C_KV, K_rotate)`` pairs, or a Cache object exposing
        ``to_legacy_cache``. ``position_ids`` is accepted for API compatibility but unused:
        RoPE positions are derived from the cache length.
        NOTE(review): left padding therefore does not shift RoPE positions — confirm intended.
        """
        output_hidden_states = (
            output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        use_checkpointing = self.gradient_checkpointing and self.training
        if use_checkpointing:
            # Checkpointed layers are recomputed in backward; caching is pointless there.
            use_cache = False

        _, seq_length = input_ids.shape
        past_key_values_length = 0

        if past_key_values is not None:
            # Convert DynamicCache (or any Cache object) to legacy tuple of tuples
            if not isinstance(past_key_values, tuple):
                if hasattr(past_key_values, "to_legacy_cache"):
                    past_key_values = past_key_values.to_legacy_cache()
                else:
                    past_key_values = None
            # An empty tuple means no real cached state yet (first generate() call)
            if past_key_values is not None and len(past_key_values) == 0:
                past_key_values = None
            if past_key_values is not None:
                past_key_values_length = past_key_values[0][0].shape[1]

        inputs_embeds = self.embed_tokens(input_ids)
        hidden_states = inputs_embeds

        # Additive mask: 0 where attention is allowed (causal AND non-padding), dtype-min elsewhere.
        kv_seq_len = seq_length + past_key_values_length
        causal_mask = torch.tril(
            torch.ones((seq_length, kv_seq_len), dtype=torch.bool, device=input_ids.device),
            diagonal=past_key_values_length,
        )
        if attention_mask is not None:
            attention_mask_expanded = attention_mask[:, None, None, :] == 1
        else:
            attention_mask_expanded = True
        mask = causal_mask[None, None, :, :] & attention_mask_expanded
        extended_attention_mask = torch.where(mask, 0.0, torch.finfo(hidden_states.dtype).min)

        # Rebuild complex float32 RoPE rotations from the real-valued buffer.
        freqs_cis = torch.view_as_complex(self.freqs_cis.float().contiguous())

        all_present_key_values = () if use_cache else None
        all_hidden_states = () if output_hidden_states else None

        x_stream = self._stream_init(hidden_states)
        for i, layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (self._readout(x_stream),)

            past_key_value = past_key_values[i] if past_key_values is not None else None
            if use_checkpointing:
                # `_gradient_checkpointing_func` is installed by gradient_checkpointing_enable().
                x_stream, present_key_value = self._gradient_checkpointing_func(
                    layer.__call__,
                    x_stream,
                    extended_attention_mask,
                    past_key_value,
                    freqs_cis,
                    use_cache,
                )
            else:
                x_stream, present_key_value = layer(
                    x_stream,
                    attention_mask=extended_attention_mask,
                    past_key_value=past_key_value,
                    freqs_cis=freqs_cis,
                    use_cache=use_cache,
                )

            if use_cache:
                all_present_key_values += (present_key_value,)

        hidden_states = self._readout(x_stream)
        hidden_states = self.norm(hidden_states)

        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        if not return_dict:
            return tuple(v for v in [hidden_states, all_present_key_values, all_hidden_states] if v is not None)

        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=all_present_key_values,
            hidden_states=all_hidden_states,
        )


class HybridForCausalLM(HybridPreTrainedModel, GenerationMixin):
    """:class:`HybridModel` with a linear LM head and next-token cross-entropy loss."""

    def __init__(self, config):
        super().__init__(config)
        self.model = HybridModel(config)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        self.post_init()

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def forward(
        self,
        input_ids=None,
        attention_mask=None,
        position_ids=None,
        past_key_values=None,
        labels=None,
        use_cache=None,
        output_hidden_states=None,
        return_dict=None,
    ):
        """Compute logits and, if ``labels`` is given, the shifted next-token loss.

        Labels equal to -100 (``CrossEntropyLoss``'s default ``ignore_index``) are ignored.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            position_ids=position_ids,
            past_key_values=past_key_values,
            use_cache=use_cache,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        hidden_states = outputs.last_hidden_state if return_dict else outputs[0]
        logits = self.lm_head(hidden_states)

        loss = None
        if labels is not None:
            # Position t predicts token t + 1.
            shift_logits = logits[..., :-1, :].contiguous()
            # Labels may live on another device (e.g. when the model is sharded).
            shift_labels = labels[..., 1:].contiguous().to(shift_logits.device)
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1))

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
        )

    def prepare_inputs_for_generation(
        self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs
    ):
        """Trim ``input_ids`` to the tokens not yet in the cache and build model inputs."""
        if past_key_values is not None:
            if hasattr(past_key_values, "get_seq_length"):
                past_length = past_key_values.get_seq_length()
            elif len(past_key_values) > 0:
                past_length = past_key_values[0][0].shape[1]
            else:
                # Empty legacy cache: nothing has been processed yet.
                past_length = 0

            if attention_mask is not None and attention_mask.shape[1] > input_ids.shape[1]:
                input_ids = input_ids[:, -1:]
            elif past_length < input_ids.shape[1]:
                input_ids = input_ids[:, past_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # Positions count only real tokens; padding positions are filled with 1.
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        model_inputs = {"input_ids": input_ids}
        model_inputs.update(
            {
                "position_ids": position_ids,
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def _reorder_cache(self, past_key_values, beam_idx):
        """Reorder every cached tensor along the batch dim to follow beam search."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(past_state.index_select(0, beam_idx.to(past_state.device)) for past_state in layer_past),
            )
        return reordered_past


HybridModelConfig.register_for_auto_class()
HybridForCausalLM.register_for_auto_class("AutoModelForCausalLM")