| | """Transformers model implementation for NeuroCoder remote-code loading.""" |
| |
|
| | from __future__ import annotations |
| |
|
| | import math |
| | from typing import Any |
| |
|
| | import torch |
| | import torch.nn.functional as F |
| | from torch import Tensor, nn |
| | from transformers import PreTrainedModel |
| | from transformers.modeling_outputs import CausalLMOutputWithPast |
| |
|
| | try: |
| | from .configuration_neurocoder import NeuroCoderConfig |
| | except Exception: |
| | from configuration_neurocoder import NeuroCoderConfig |
| |
|
| |
|
class RMSNorm(nn.Module):
    """Root-mean-square normalisation over the last dimension with a learned gain.

    Args:
        hidden_size: Size of the last (normalised) dimension; also the length
            of the learned ``weight`` vector, which starts at all ones.
        eps: Constant added to the mean of squares before the reciprocal
            square root, guarding against division by zero.
    """

    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, x: Tensor) -> Tensor:
        """Return ``x / sqrt(mean(x**2) + eps) * weight`` along the last axis."""
        input_dtype = x.dtype
        # Squaring float16 activations overflows once |x| exceeds ~255, which
        # turns the statistic into inf and the output into 0. Half-precision
        # inputs are therefore reduced in float32; wider dtypes are left alone
        # so their results are unchanged.
        if input_dtype in (torch.float16, torch.bfloat16):
            compute_dtype = torch.float32
        else:
            compute_dtype = input_dtype
        wide = x.to(compute_dtype)
        mean_square = wide.pow(2).mean(-1, keepdim=True)
        normalised = (wide * torch.rsqrt(mean_square + self.eps)).to(input_dtype)
        return normalised * self.weight
| |
|
| |
|
class SelfAttention(nn.Module):
    """Multi-head causal self-attention with an optional key/value cache.

    Reads ``num_heads``, ``head_dim`` and ``hidden_size`` from ``config``.
    Queries, keys and values come from one fused projection.
    """

    def __init__(self, config: NeuroCoderConfig) -> None:
        super().__init__()
        self.num_heads = config.num_heads
        self.head_dim = config.head_dim
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(config.hidden_size, config.hidden_size * 3)
        self.out = nn.Linear(config.hidden_size, config.hidden_size)

    def forward(
        self,
        x: Tensor,
        past_key_value: tuple[Tensor, Tensor] | None = None,
        attention_mask: Tensor | None = None,
        use_cache: bool = False,
    ) -> tuple[Tensor, tuple[Tensor, Tensor] | None]:
        """Attend over the cached keys (if any) followed by the keys of ``x``.

        Args:
            x: Hidden states of shape ``(batch, seq_len, hidden)``.
            past_key_value: Optional ``(key, value)`` pair from earlier steps,
                each ``(batch, num_heads, past_len, head_dim)``; it is
                concatenated in front of the new keys/values.
            attention_mask: Optional ``(batch, total_len)`` mask whose truthy
                entries mark keys that may be attended to. Only the trailing
                ``key_len`` columns are used.
            use_cache: When true, the concatenated ``(key, value)`` pair is
                returned so the caller can feed it back on the next step.

        Returns:
            ``(output, present)`` where ``output`` has the shape of ``x`` and
            ``present`` is the updated cache pair or ``None``.
        """
        bsz, seq_len, hidden = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)

        def shape_heads(t: Tensor) -> Tensor:
            # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
            return t.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        q, k, v = shape_heads(q), shape_heads(k), shape_heads(v)

        if past_key_value is not None:
            past_k, past_v = past_key_value
            if past_k is not None and past_v is not None:
                k = torch.cat([past_k, k], dim=2)
                v = torch.cat([past_v, v], dim=2)

        present = (k, v) if use_cache else None
        key_len = k.shape[-2]
        past_len = key_len - seq_len

        attn = torch.matmul(q, k.transpose(-2, -1)) * self.scale

        # Build one boolean "may attend" mask from the causal constraint and
        # the padding mask; ``None`` means nothing is masked.
        allowed: Tensor | None = None
        if seq_len > 1 or past_len > 0:
            q_positions = torch.arange(
                past_len,
                past_len + seq_len,
                device=x.device,
            ).unsqueeze(-1)
            k_positions = torch.arange(key_len, device=x.device).unsqueeze(0)
            allowed = (k_positions <= q_positions).unsqueeze(0).unsqueeze(0)
        if attention_mask is not None:
            key_mask = attention_mask[:, -key_len:].to(dtype=torch.bool).unsqueeze(1).unsqueeze(1)
            allowed = key_mask if allowed is None else allowed & key_mask

        if allowed is not None:
            attn = attn.masked_fill(~allowed, float("-inf"))
        probs = F.softmax(attn, dim=-1)
        if allowed is not None:
            # A query whose every key is masked (e.g. a left-padding position)
            # has an all -inf row, for which softmax returns NaN. Zeroing the
            # masked entries leaves ordinary rows untouched (they are already
            # 0 there) and turns such rows into all-zero attention, so NaNs
            # never reach the output or the cache.
            probs = probs.masked_fill(~allowed, 0.0)

        out = torch.matmul(probs, v)
        out = out.transpose(1, 2).contiguous().view(bsz, seq_len, hidden)
        return self.out(out), present
| |
|
| |
|
class DenseFFN(nn.Module):
    """Gated feed-forward block: ``down(silu(gate(x)) * up(x))``.

    The inner width is ``config.hidden_size * config.ffn_multiplier``.
    """

    def __init__(self, config: NeuroCoderConfig) -> None:
        super().__init__()
        width = config.hidden_size
        expanded = width * config.ffn_multiplier
        # Creation order (gate, up, down) is kept so seeded initialisation
        # and state-dict ordering stay the same.
        self.gate = nn.Linear(width, expanded)
        self.up = nn.Linear(width, expanded)
        self.down = nn.Linear(expanded, width)

    def forward(self, x: Tensor) -> Tensor:
        """Apply the gated projection; the output has the same shape as ``x``."""
        gated = F.silu(self.gate(x))
        lifted = self.up(x)
        return self.down(gated * lifted)
| |
|
| |
|
class MoEFeedForward(nn.Module):
    """Top-k routed mixture of ``DenseFFN`` experts with a per-expert capacity.

    Reads ``num_experts``, ``router_top_k``, ``capacity_factor_train``,
    ``capacity_factor_infer`` and ``hidden_size`` from ``config``.
    """

    def __init__(self, config: NeuroCoderConfig) -> None:
        super().__init__()
        self.num_experts = config.num_experts
        self.top_k = config.router_top_k
        self.capacity_factor_train = config.capacity_factor_train
        self.capacity_factor_infer = config.capacity_factor_infer
        self.router = nn.Linear(config.hidden_size, config.num_experts, bias=False)
        self.experts = nn.ModuleList(DenseFFN(config) for _ in range(config.num_experts))

    def forward(self, x: Tensor) -> tuple[Tensor, Tensor, Tensor]:
        """Route every token to its top-k experts and mix their outputs.

        Args:
            x: Hidden states of shape ``(batch, seq_len, hidden)``.

        Returns:
            ``(output, aux_loss, z_loss)``: the mixed expert output with the
            shape of ``x``, a scalar built from mean router probability times
            the fraction of tokens each expert processed, and the mean squared
            log-sum-exp of the router logits.
        """
        bsz, seq_len, hidden = x.shape
        flat = x.reshape(-1, hidden)
        n_tokens = flat.shape[0]

        router_logits = self.router(flat)
        router_probs = F.softmax(router_logits, dim=-1)
        gate_vals, gate_idx = torch.topk(router_probs, k=self.top_k, dim=-1)

        if self.training:
            factor = self.capacity_factor_train
        else:
            factor = self.capacity_factor_infer
        capacity = max(1, math.ceil(factor * n_tokens / self.num_experts))

        # Viewed as (top_k, tokens), ``nonzero`` lists matches rank by rank and
        # by ascending token index within a rank. That ordering decides which
        # assignments survive the capacity cut below.
        choice_by_rank = gate_idx.t()

        combined = torch.zeros_like(flat)
        load_fractions: list[float] = []
        for expert_id, expert in enumerate(self.experts):
            ranks, routed = torch.nonzero(choice_by_rank == expert_id, as_tuple=True)
            if routed.numel() == 0:
                load_fractions.append(0.0)
                continue

            weights = gate_vals[routed, ranks]
            # Assignments beyond the capacity are dropped; the slice is a
            # no-op when the expert is under capacity.
            routed = routed[:capacity]
            weights = weights[:capacity]

            combined[routed] += expert(flat[routed]) * weights.unsqueeze(-1)
            load_fractions.append(float(routed.numel() / max(n_tokens, 1)))

        load = torch.tensor(load_fractions, device=x.device)
        importance = router_probs.mean(dim=0)
        aux_loss = self.num_experts * torch.sum(importance * load)
        z_loss = torch.mean(torch.logsumexp(router_logits, dim=-1) ** 2)
        return combined.reshape(bsz, seq_len, hidden), aux_loss, z_loss
| |
|
| |
|
class TransformerBlock(nn.Module):
    """Pre-norm residual block: self-attention followed by a feed-forward.

    Args:
        config: Model configuration forwarded to the sub-modules.
        use_moe: When true the feed-forward is a ``MoEFeedForward``, otherwise
            a ``DenseFFN``.
    """

    def __init__(self, config: NeuroCoderConfig, use_moe: bool) -> None:
        super().__init__()
        self.norm1 = RMSNorm(config.hidden_size)
        self.norm2 = RMSNorm(config.hidden_size)
        self.attn = SelfAttention(config)
        if use_moe:
            self.ffn = MoEFeedForward(config)
        else:
            self.ffn = DenseFFN(config)
        self.use_moe = use_moe

    def forward(
        self,
        x: Tensor,
        past_key_value: tuple[Tensor, Tensor] | None = None,
        attention_mask: Tensor | None = None,
        use_cache: bool = False,
    ) -> tuple[Tensor, Tensor, Tensor, tuple[Tensor, Tensor] | None]:
        """Run attention and feed-forward, each wrapped in a residual connection.

        Returns:
            ``(hidden, aux_loss, z_loss, present)``. Both losses are zero
            scalars on the device of the hidden states for dense blocks;
            ``present`` is whatever cache entry the attention module returned.
        """
        attn_out, present = self.attn(
            self.norm1(x),
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            use_cache=use_cache,
        )
        after_attn = x + attn_out
        normed = self.norm2(after_attn)

        if self.use_moe:
            ffn_out, aux_loss, z_loss = self.ffn(normed)
        else:
            ffn_out = self.ffn(normed)
            # Dense blocks contribute nothing to the routing losses.
            aux_loss = torch.tensor(0.0, device=after_attn.device)
            z_loss = torch.tensor(0.0, device=after_attn.device)

        return after_attn + ffn_out, aux_loss, z_loss, present
| |
|
| |
|
class NeuroCoderForCausalLM(PreTrainedModel):
    """Decoder-only causal language model exposed through the Transformers API.

    Token embeddings plus learned absolute position embeddings feed a stack of
    ``TransformerBlock`` layers; every ``config.moe_every_n_layers``-th block
    uses a mixture-of-experts feed-forward. The output projection shares its
    weight with the token embedding. Key/value caches are returned in the
    legacy format: one ``(key, value)`` pair per layer.
    """

    config_class = NeuroCoderConfig
    base_model_prefix = "neurocoder"
    _no_split_modules = ["TransformerBlock", "MoEFeedForward"]
    _supports_cache_class = False

    def __init__(self, config: NeuroCoderConfig) -> None:
        super().__init__(config)
        self.token_embed = nn.Embedding(config.vocab_size, config.hidden_size)
        self.pos_embed = nn.Embedding(config.context_length, config.hidden_size)
        self.layers = nn.ModuleList(
            [
                TransformerBlock(config, use_moe=((idx + 1) % config.moe_every_n_layers == 0))
                for idx in range(config.num_layers)
            ]
        )
        self.norm = RMSNorm(config.hidden_size)
        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)
        # Input and output embeddings share a single parameter.
        self.lm_head.weight = self.token_embed.weight
        self.post_init()

    def get_input_embeddings(self) -> nn.Embedding:
        """Return the token embedding table."""
        return self.token_embed

    def set_input_embeddings(self, value: nn.Embedding) -> None:
        """Replace the token embedding table."""
        self.token_embed = value

    def get_output_embeddings(self) -> nn.Linear:
        """Return the output projection."""
        return self.lm_head

    def set_output_embeddings(self, new_embeddings: nn.Linear) -> None:
        """Replace the output projection."""
        self.lm_head = new_embeddings

    @staticmethod
    def _cached_length(past: Any) -> int:
        """Return the cached sequence length of the first populated layer entry.

        ``past`` is an iterable of per-layer entries; an entry counts as
        populated when it is a 2-item tuple/list whose items are both not
        ``None``. Returns 0 when no layer holds a cache.
        """
        for entry in past:
            if (
                isinstance(entry, (tuple, list))
                and len(entry) == 2
                and entry[0] is not None
                and entry[1] is not None
            ):
                # Keys are laid out (batch, heads, seq, head_dim).
                return int(entry[0].shape[2])
        return 0

    def prepare_inputs_for_generation(
        self,
        input_ids: Tensor,
        **kwargs: Any,
    ) -> dict[str, Any]:
        """Build the ``forward`` kwargs for one generation step.

        When a non-empty cache is present only the last token of ``input_ids``
        is fed, because earlier tokens are already represented in the cache.
        Cache objects are probed through ``get_seq_length``; legacy tuple or
        list caches are inspected with the same rule ``forward`` uses.
        """
        past_key_values = kwargs.get("past_key_values")
        if past_key_values is not None and hasattr(past_key_values, "get_seq_length"):
            has_past = bool(past_key_values.get_seq_length() > 0)
        elif isinstance(past_key_values, (tuple, list)):
            has_past = self._cached_length(past_key_values) > 0
        else:
            has_past = False

        if has_past:
            input_ids = input_ids[:, -1:]
        return {
            "input_ids": input_ids,
            "attention_mask": kwargs.get("attention_mask"),
            "past_key_values": past_key_values,
            "use_cache": kwargs.get("use_cache", True),
        }

    @staticmethod
    def _as_legacy_past_key_values(
        past_key_values: Any,
        num_layers: int,
    ) -> tuple[tuple[Tensor, Tensor] | None, ...]:
        """Normalise any supported cache container to a per-layer tuple.

        Accepts ``None``, a tuple or list of per-layer pairs, an object with
        ``to_legacy_cache()``, or an object exposing ``key_cache`` and
        ``value_cache`` lists. Anything unrecognised is treated as no cache.
        Tuples and lists are returned as given, so the result may hold fewer
        than ``num_layers`` entries.
        """
        empty = (None,) * num_layers
        if past_key_values is None:
            return empty

        if hasattr(past_key_values, "to_legacy_cache"):
            past_key_values = past_key_values.to_legacy_cache()

        if isinstance(past_key_values, list):
            past_key_values = tuple(past_key_values)
        if isinstance(past_key_values, tuple):
            return past_key_values

        key_cache = getattr(past_key_values, "key_cache", None)
        value_cache = getattr(past_key_values, "value_cache", None)
        if isinstance(key_cache, list) and isinstance(value_cache, list):
            pairs: list[tuple[Tensor, Tensor] | None] = []
            for idx in range(num_layers):
                pair = None
                if idx < len(key_cache) and idx < len(value_cache):
                    key, value = key_cache[idx], value_cache[idx]
                    if key is not None and value is not None:
                        pair = (key, value)
                pairs.append(pair)
            return tuple(pairs)

        return empty

    def _reorder_cache(
        self,
        past_key_values: tuple[tuple[Tensor, Tensor], ...] | list[tuple[Tensor, Tensor]],
        beam_idx: Tensor,
    ) -> tuple[tuple[Tensor, Tensor], ...]:
        """Reindex the batch dimension of every cached tensor by ``beam_idx``."""
        return tuple(
            (key.index_select(0, beam_idx), value.index_select(0, beam_idx))
            for key, value in past_key_values
        )

    def forward(
        self,
        input_ids: Tensor | None = None,
        attention_mask: Tensor | None = None,
        labels: Tensor | None = None,
        past_key_values: Any = None,
        use_cache: bool | None = None,
        **kwargs: Any,
    ) -> CausalLMOutputWithPast:
        """Run the model and optionally compute the training loss.

        Args:
            input_ids: Token ids of shape ``(batch, seq_len)``. Required.
            attention_mask: Optional key mask passed unchanged to every layer.
            labels: Optional targets with the same number of elements as
                ``input_ids``; entries equal to ``-100`` are ignored.
            past_key_values: Optional cache in any format accepted by
                ``_as_legacy_past_key_values``.
            use_cache: Whether to return the updated cache; defaults to
                ``config.use_cache`` when ``None``.
            **kwargs: Accepted for interface compatibility and ignored.

        Returns:
            ``CausalLMOutputWithPast`` with ``logits``, ``loss`` (only when
            ``labels`` is given) and ``past_key_values`` (only when caching).

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("input_ids is required")

        cache_enabled = bool(self.config.use_cache if use_cache is None else use_cache)
        past = self._as_legacy_past_key_values(past_key_values, len(self.layers))
        bsz, seq_len = input_ids.shape
        past_len = self._cached_length(past)

        pos = torch.arange(
            past_len,
            past_len + seq_len,
            device=input_ids.device,
        ).unsqueeze(0).expand(bsz, seq_len)
        # Positions beyond the embedding table reuse its last row instead of
        # raising an index error.
        pos = pos.clamp_max(self.config.context_length - 1)
        x = self.token_embed(input_ids) + self.pos_embed(pos)

        aux_loss = torch.tensor(0.0, device=input_ids.device)
        z_loss = torch.tensor(0.0, device=input_ids.device)
        present_key_values: list[tuple[Tensor, Tensor]] = []

        for layer_idx, layer in enumerate(self.layers):
            # A cache shorter than the layer stack leaves the tail uncached.
            layer_past = past[layer_idx] if layer_idx < len(past) else None
            x, layer_aux, layer_z, layer_present = layer(
                x,
                past_key_value=layer_past,
                attention_mask=attention_mask,
                use_cache=cache_enabled,
            )
            aux_loss = aux_loss + layer_aux
            z_loss = z_loss + layer_z
            if cache_enabled and layer_present is not None:
                present_key_values.append(layer_present)

        x = self.norm(x)
        logits = self.lm_head(x)

        loss = None
        if labels is not None:
            # NOTE(review): logits and labels are compared position by position
            # with no internal shift, so callers must pass targets already
            # shifted by one. Many Transformers causal-LM heads shift
            # internally -- confirm which convention the training code expects.
            # ``reshape`` rather than ``view`` so non-contiguous labels (for
            # example a slice of a larger tensor) do not raise.
            loss = F.cross_entropy(
                logits.reshape(-1, logits.size(-1)),
                labels.reshape(-1),
                ignore_index=-100,
            )
            loss = loss + 0.01 * aux_loss + 0.001 * z_loss

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=tuple(present_key_values) if cache_enabled else None,
        )
| |
|