| import torch
|
| import torch.nn as nn
|
| import math
|
| from ._ops import ops
|
|
|
|
|
def matmul_persistent(
    a: torch.Tensor, b: torch.Tensor, bias: torch.Tensor = None
) -> torch.Tensor:
    """
    Persistent matrix multiplication with optional bias.

    Computes ``a @ b`` (adding ``bias`` when given) by calling the custom
    ``ops.matmul_persistent`` kernel, which writes into a freshly allocated
    output tensor.

    Args:
        a: Input tensor of shape (M, K)
        b: Input tensor of shape (K, N), same dtype and device as ``a``
        bias: Optional bias tensor of shape (N,), on the same device as ``a``

    Returns:
        Output tensor of shape (M, N) with ``a``'s dtype and device

    Raises:
        AssertionError: if ranks, shapes, dtypes or devices are incompatible.
    """
    # Validate everything up front: the tensors are handed to an opaque custom
    # kernel, so a mismatch would otherwise be passed straight through to it
    # instead of producing a readable error here.
    assert a.dim() == 2 and b.dim() == 2, "Inputs must be 2D"
    assert a.shape[1] == b.shape[0], "Incompatible dimensions"
    assert a.dtype == b.dtype, "Incompatible dtypes"
    assert a.device == b.device, "Inputs must be on the same device"

    M = a.shape[0]
    N = b.shape[1]

    if bias is not None:
        assert bias.dim() == 1, "Bias must be 1D"
        assert bias.shape[0] == N, (
            f"Bias must have {N} elements, got {bias.shape[0]}"
        )
        assert bias.device == a.device, "Bias must be on the same device as inputs"

    c = torch.empty((M, N), device=a.device, dtype=a.dtype)
    ops.matmul_persistent(a, b, c, bias)
    return c
|
|
|
|
|
def log_softmax(input: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """
    Apply log-softmax over the trailing dimension using the custom kernel.

    Args:
        input: Input tensor
        dim: Reduction dimension; only the last one (``-1`` or
            ``input.ndim - 1``) is accepted

    Returns:
        New tensor shaped like ``input`` holding the log-softmax values

    Raises:
        ValueError: if ``dim`` is not the last dimension.
    """
    last_axis = input.ndim - 1
    if dim not in (-1, last_axis):
        raise ValueError(
            "This implementation only supports log_softmax along the last dimension"
        )

    # The kernel fills a caller-provided buffer rather than returning one.
    result = torch.empty_like(input)
    ops.log_softmax(input, result)
    return result
|
|
|
|
|
def mean_dim(
    input: torch.Tensor, dim: int, keepdim: bool = False, dtype: torch.dtype = None
) -> torch.Tensor:
    """
    Compute mean along a single dimension with the custom CUDA kernel.

    Args:
        input: Input CUDA tensor
        dim: Single dimension along which to compute the mean (may be negative)
        keepdim: Whether to keep the reduced dimension with size 1
        dtype: Output dtype. Defaults to ``input.dtype`` for floating/complex
            inputs and to float32 for everything else (bool, signed and
            unsigned integers). ``input`` is cast to this dtype before reducing.

    Returns:
        Tensor with mean values along the specified dimension

    Raises:
        AssertionError: if ``input`` is not on CUDA or ``dim`` is out of range.
    """
    assert input.is_cuda, "Input must be a CUDA tensor"
    assert -input.ndim <= dim < input.ndim, f"Invalid dimension {dim}"

    if dim < 0:
        dim += input.ndim

    if dtype is None:
        # A mean of integral/boolean data is not representable in that dtype,
        # so every non-floating input (including bool and uint8, which the
        # explicit int8..int64 list used to miss) is promoted to float32.
        if input.dtype.is_floating_point or input.dtype.is_complex:
            dtype = input.dtype
        else:
            dtype = torch.float32

    if input.dtype != dtype:
        input = input.to(dtype)

    shape = list(input.shape)
    if keepdim:
        output_shape = shape.copy()
        output_shape[dim] = 1
    else:
        output_shape = shape[:dim] + shape[dim + 1 :]

    # The kernel writes into a preallocated output of the reduced shape.
    output = torch.empty(output_shape, dtype=dtype, device=input.device)
    ops.mean_dim(input, output, dim)
    return output
|
|
|
|
|
|
|
def mm_batch_invariant(a, b):
    """Batch-invariant stand-in for ``torch.mm``: ``a @ b`` without a bias."""
    product = matmul_persistent(a, b, bias=None)
    return product
|
|
|
|
|
def addmm_batch_invariant(bias, a, b):
    """
    Batch-invariant stand-in for ``torch.addmm``: ``a @ b`` plus ``bias``.

    The bias comes first, following ``torch.addmm(input, mat1, mat2)``;
    ``beta``/``alpha`` scaling is not supported by this signature.
    """
    result = matmul_persistent(a, b, bias)
    return result
|
|
|
|
|
def _log_softmax_batch_invariant(input, dim, _half_to_float):
    """
    Adapter with the ``(self, dim, half_to_float)`` argument order of
    ``aten::_log_softmax``; presumably registered as an override elsewhere.

    Half-to-float upcasting is not supported and trips an assertion.
    """
    assert not _half_to_float, "not implemented"
    out = log_softmax(input, dim=dim)
    return out
|
|
|
|
|
def mean_batch_invariant(input, dim, keepdim=False, dtype: torch.dtype = None):
    """
    Batch-invariant mean over one or more dimensions.

    A single reduced dimension is delegated to ``mean_dim``. Several dimensions
    are reduced with ``torch.sum`` (accumulating in float32, or float64 when
    that output dtype is requested) divided by the number of reduced elements.

    Args:
        input: Input tensor
        dim: Dimension(s) to reduce: an int, a sequence of ints, or ``None`` /
            an empty sequence to reduce over every dimension
        keepdim: Whether to keep the reduced dimensions with size 1
        dtype: Output dtype. Defaults to ``input.dtype`` for floating/complex
            inputs and float32 otherwise, the same rule ``mean_dim`` uses.

    Returns:
        Tensor with mean values over the requested dimensions
    """
    # Normalise ``dim`` to a list. ``None`` or an empty sequence means "reduce
    # over everything" (as in ``torch.mean``), so the divisor must then be the
    # total element count rather than 1.
    if dim is None:
        dims = list(range(input.ndim))
    elif isinstance(dim, int):
        dims = [dim]
    else:
        dims = list(dim) or list(range(input.ndim))

    if len(dims) == 1:
        return mean_dim(input, dims[0], keepdim=keepdim, dtype=dtype)

    if dtype is None:
        if input.dtype.is_floating_point or input.dtype.is_complex:
            dtype = input.dtype
        else:
            dtype = torch.float32

    acc_dtype = torch.float64 if dtype == torch.float64 else torch.float32

    n_elems = 1
    for d in dims:
        n_elems *= input.shape[d]

    total = torch.sum(input, dim=dims, keepdim=keepdim, dtype=acc_dtype)
    return (total / n_elems).to(dtype)
|
|
|
class BatchInvariantAttention(nn.Module):
    """
    Batch invariant multi-head attention implementation.

    Compatible with transformers library integration.

    The q/k/v/o projections run through ``matmul_persistent`` and the softmax
    through the custom ``log_softmax`` kernel (exponentiated afterwards).

    NOTE(review): the score (``Q @ K^T``) and context (``P @ V``) products use
    plain ``torch.matmul`` and are not routed through ``matmul_persistent``;
    confirm they are batch invariant on the target backend.
    NOTE(review): ``position_ids`` and ``cache_position`` are accepted but not
    used (no rotary embedding is applied here), and ``past_key_value`` is
    returned unchanged rather than updated.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.num_heads = config.num_attention_heads
        self.head_dim = self.hidden_size // self.num_heads
        self.max_position_embeddings = getattr(config, "max_position_embeddings", 2048)

        if (self.head_dim * self.num_heads) != self.hidden_size:
            raise ValueError(
                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {self.num_heads})."
            )

        self.q_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.k_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.v_proj = nn.Linear(
            self.hidden_size, self.num_heads * self.head_dim, bias=False
        )
        self.o_proj = nn.Linear(
            self.num_heads * self.head_dim, self.hidden_size, bias=False
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: torch.Tensor = None,
        position_ids: torch.Tensor = None,
        past_key_value=None,
        output_attentions: bool = False,
        use_cache: bool = False,
        cache_position: torch.Tensor = None,
        **kwargs,
    ):
        """
        Run self-attention over ``hidden_states``.

        Args:
            hidden_states: Tensor of shape (batch, seq_len, hidden_size)
            attention_mask: Optional additive mask, added to the raw scores of
                shape (batch, num_heads, seq_len, seq_len); must broadcast to it
            position_ids: Accepted for interface compatibility; unused
            past_key_value: Returned unchanged when ``use_cache`` is set
            output_attentions: Also return the attention probabilities
            use_cache: Append ``past_key_value`` to the outputs
            cache_position: Accepted for interface compatibility; unused
            **kwargs: Ignored

        Returns:
            Tuple ``(attn_output, [attn_weights], [past_key_value])``; the
            optional entries are present only when requested.
        """
        batch_size, seq_len, _ = hidden_states.size()

        query_states = self._batch_invariant_linear(hidden_states, self.q_proj.weight)
        key_states = self._batch_invariant_linear(hidden_states, self.k_proj.weight)
        value_states = self._batch_invariant_linear(hidden_states, self.v_proj.weight)

        # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
        query_states = query_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        key_states = key_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)
        value_states = value_states.view(
            batch_size, seq_len, self.num_heads, self.head_dim
        ).transpose(1, 2)

        # Scaled dot-product scores: (batch, heads, seq, seq).
        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attention_mask is not None:
            attn_weights = attn_weights + attention_mask

        # Softmax is computed as exp(log_softmax) so the normalisation goes
        # through the custom log_softmax kernel.
        attn_weights_log = log_softmax(attn_weights, dim=-1)
        attn_weights = torch.exp(attn_weights_log)

        attn_output = torch.matmul(attn_weights, value_states)

        # (batch, heads, seq, head_dim) -> (batch, seq, hidden)
        attn_output = attn_output.transpose(1, 2).contiguous()
        attn_output = attn_output.reshape(batch_size, seq_len, self.hidden_size)
        attn_output = self._batch_invariant_linear(attn_output, self.o_proj.weight)

        outputs = (attn_output,)
        if output_attentions:
            outputs += (attn_weights,)
        if use_cache:
            outputs += (past_key_value,)

        return outputs

    def _batch_invariant_linear(
        self, input_tensor: torch.Tensor, weight: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply ``input_tensor @ weight.T`` using batch invariant matrix multiplication.

        Leading dimensions are flattened into rows for the 2-D kernel and
        restored afterwards. ``reshape`` (not ``view``) is used so that
        non-contiguous inputs are accepted instead of raising.
        """
        original_shape = input_tensor.shape
        input_2d = input_tensor.reshape(-1, original_shape[-1])
        output_2d = matmul_persistent(input_2d, weight.t())
        return output_2d.reshape(*original_shape[:-1], -1)
|
|
|
|
|
class BatchInvariantMLP(nn.Module):
    """
    Batch invariant MLP implementation.

    Gated feed-forward block computing
    ``down_proj(act_fn(gate_proj(x)) * up_proj(x))``, with every projection
    performed through ``matmul_persistent``.

    NOTE(review): the activation is hard-coded to SiLU; ``config.hidden_act``
    (if the config has one) is not consulted — confirm this matches the model.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        self.intermediate_size = config.intermediate_size

        self.gate_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.up_proj = nn.Linear(self.hidden_size, self.intermediate_size, bias=False)
        self.down_proj = nn.Linear(self.intermediate_size, self.hidden_size, bias=False)
        self.act_fn = nn.SiLU()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the gated MLP to ``x`` (last dimension must be ``hidden_size``)."""
        gate = self._batch_invariant_linear(x, self.gate_proj.weight)
        up = self._batch_invariant_linear(x, self.up_proj.weight)

        intermediate = self.act_fn(gate) * up

        output = self._batch_invariant_linear(intermediate, self.down_proj.weight)
        return output

    def _batch_invariant_linear(
        self, input_tensor: torch.Tensor, weight: torch.Tensor
    ) -> torch.Tensor:
        """
        Apply ``input_tensor @ weight.T`` using batch invariant matrix multiplication.

        Leading dimensions are flattened into rows for the 2-D kernel and
        restored afterwards. ``reshape`` (not ``view``) is used so that
        non-contiguous inputs are accepted instead of raising.
        """
        original_shape = input_tensor.shape
        input_2d = input_tensor.reshape(-1, original_shape[-1])
        output_2d = matmul_persistent(input_2d, weight.t())
        return output_2d.reshape(*original_shape[:-1], -1)
|
|
|
|
|
class BatchInvariantRMSNorm(nn.Module):
    """
    Batch invariant RMS normalization implementation.

    Scales each vector along the last dimension by the reciprocal root of its
    mean square (computed in float32 via ``mean_dim``), casts back to the
    input dtype, then multiplies by a learned per-channel ``weight``.
    """

    def __init__(self, hidden_size, eps=1e-6):
        super().__init__()
        self.variance_epsilon = eps
        self.weight = nn.Parameter(torch.ones(hidden_size))

    def forward(self, hidden_states):
        orig_dtype = hidden_states.dtype
        x32 = hidden_states.to(torch.float32)

        # Mean of squares over the last dim, kept for broadcasting.
        mean_sq = mean_dim(x32.pow(2), dim=-1, keepdim=True)
        normed = (x32 * torch.rsqrt(mean_sq + self.variance_epsilon)).to(orig_dtype)

        return self.weight * normed
|
|
|
|
|
|
|
# Public API of this module. Defined outright: nothing earlier in the module
# creates `__all__`, so augmenting it with `+=` raised NameError at import.
__all__ = [
    "matmul_persistent",
    "log_softmax",
    "mean_dim",
    "mm_batch_invariant",
    "addmm_batch_invariant",
    "mean_batch_invariant",
    "BatchInvariantAttention",
    "BatchInvariantMLP",
    "BatchInvariantRMSNorm",
]