import warnings
from dataclasses import dataclass
from typing import Optional

import numpy as np


@dataclass
class ProjectionConfig:
    """Configuration for final projection layer"""
    hidden_dim: int
    vocab_size: int
    use_bias: bool = True
    use_fp16: bool = False
    use_weight_tying: bool = False
    dropout_rate: float = 0.0
    initializer_range: float = 0.02

class FinalProjection:
    """
    Optimized final projection layer for transformer models with support for:
    - Weight tying with input embeddings
    - Mixed precision (FP16/FP32)
    - Memory-efficient computation
    - Optimized matrix multiplication
    - Bias fusion
    """

    def __init__(
        self,
        config: ProjectionConfig,
        embedding_weights: Optional[np.ndarray] = None,
        driver=None
    ):
        """
        Initialize the final projection layer.

        Args:
            config: Projection configuration
            embedding_weights: Optional weights for weight tying with input embeddings
            driver: Optional hardware driver for optimized computation
        """
        self.config = config
        self.driver = driver

        if config.use_weight_tying and embedding_weights is not None:
            # Tie the output projection to the input embedding matrix
            # (vocab_size, hidden_dim), transposed to (hidden_dim, vocab_size).
            self.weight = embedding_weights.T
        else:
            self.weight = self._initialize_weights()

        if config.use_bias:
            self.bias = np.zeros(config.vocab_size, dtype=self._get_dtype())
        else:
            self.bias = None

        self._setup_cache()

    def _get_dtype(self) -> np.dtype:
        """Get the appropriate dtype based on configuration"""
        return np.float16 if self.config.use_fp16 else np.float32

    def _initialize_weights(self) -> np.ndarray:
        """Initialize projection weights"""
        return np.random.normal(
            0.0,
            self.config.initializer_range,
            (self.config.hidden_dim, self.config.vocab_size)
        ).astype(self._get_dtype())

    def _setup_cache(self):
        """Setup computation cache for optimizations"""
        self._cached_shapes = {}
        # Default both prepared buffers to None so the attributes always exist,
        # even when no driver is available or the layer has no bias.
        self._prepared_weight = None
        self._prepared_bias = None
        if self.driver and hasattr(self.driver, 'prepare_projection'):
            self._prepared_weight = self.driver.prepare_projection(self.weight)
            if self.bias is not None:
                self._prepared_bias = self.driver.prepare_bias(self.bias)

    def _apply_dropout(
        self,
        x: np.ndarray,
        training: bool = False
    ) -> np.ndarray:
        """Apply dropout if configured"""
        if training and self.config.dropout_rate > 0:
            # Inverted dropout: scale kept activations by 1 / keep_prob so no
            # rescaling is needed at inference time.
            mask = np.random.binomial(
                1,
                1.0 - self.config.dropout_rate,
                x.shape
            ).astype(self._get_dtype()) / (1.0 - self.config.dropout_rate)
            return x * mask
        return x

    def _validate_input(self, x: np.ndarray):
        """Validate input tensor shape and type"""
        if x.ndim != 3:
            raise ValueError(
                f"Expected 3D input tensor (batch, seq_len, hidden_dim), got shape {x.shape}"
            )
        if x.shape[-1] != self.config.hidden_dim:
            raise ValueError(
                f"Input hidden dimension {x.shape[-1]} doesn't match "
                f"configured hidden_dim {self.config.hidden_dim}"
            )

    def _optimize_computation(
        self,
        x: np.ndarray,
        batch_size: int,
        seq_len: int
    ) -> Optional[np.ndarray]:
        """Optimize computation based on input shape and hardware.

        Returns None when no driver-optimized path is available, in which case
        the caller falls back to the plain NumPy implementation.
        """
        shape_key = (batch_size, seq_len)

        if shape_key in self._cached_shapes:
            return self._cached_shapes[shape_key](x)

        if self.driver and hasattr(self.driver, 'optimized_projection'):
            # Build and cache a compute plan specialized to this input shape.
            compute_plan = self.driver.optimized_projection(
                batch_size,
                seq_len,
                self._prepared_weight,
                self._prepared_bias
            )
            self._cached_shapes[shape_key] = compute_plan
            return compute_plan(x)

        return None

    def forward(
        self,
        x: np.ndarray,
        training: bool = False
    ) -> np.ndarray:
        """
        Forward pass of the final projection layer.

        Args:
            x: Input tensor of shape (batch_size, seq_len, hidden_dim)
            training: Whether in training mode (enables dropout)

        Returns:
            logits: Output logits of shape (batch_size, seq_len, vocab_size)
        """
        self._validate_input(x)

        x = x.astype(self._get_dtype())
        x = self._apply_dropout(x, training)

        batch_size, seq_len, _ = x.shape

        # Try the driver-provided, shape-specialized compute plan first.
        optimized_result = self._optimize_computation(x, batch_size, seq_len)
        if optimized_result is not None:
            return optimized_result

        if self.driver and self._prepared_weight is not None:
            logits = self.driver.matmul(x, self._prepared_weight)
            if self._prepared_bias is not None:
                logits = self.driver.add_bias(logits, self._prepared_bias)
        else:
            # Plain NumPy fallback: flatten to (batch * seq_len, hidden_dim)
            # for a single 2D matmul, then restore the sequence dimensions.
            x_2d = x.reshape(-1, self.config.hidden_dim)
            logits = np.matmul(x_2d, self.weight)

            if self.bias is not None:
                logits += self.bias

        logits = logits.reshape(batch_size, seq_len, self.config.vocab_size)

        return logits
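
# A minimal usage sketch (hypothetical toy sizes, no hardware driver, so the
# plain NumPy fallback path is exercised). The helper name below is
# illustrative only and not part of the module's public API.
def _example_usage() -> np.ndarray:
    config = ProjectionConfig(hidden_dim=64, vocab_size=1000, dropout_rate=0.1)
    projection = FinalProjection(config)
    hidden_states = np.random.randn(2, 16, 64).astype(np.float32)
    # Logits have shape (2, 16, 1000); pass training=True to enable dropout.
    return projection.forward(hidden_states)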

def final_linear_projection(
    x: np.ndarray,
    W: np.ndarray,
    b: Optional[np.ndarray] = None,
    driver=None
) -> np.ndarray:
    """
    Legacy function for backward compatibility.

    Args:
        x: Input tensor (batch, seq_len, hidden_dim)
        W: Weight matrix (hidden_dim, vocab_size)
        b: Optional bias vector (vocab_size,)
        driver: Optional hardware driver

    Returns:
        logits: Output logits (batch, seq_len, vocab_size)
    """
    warnings.warn(
        "final_linear_projection is deprecated, use FinalProjection class instead",
        DeprecationWarning,
        stacklevel=2
    )

    config = ProjectionConfig(
        hidden_dim=W.shape[0],
        vocab_size=W.shape[1],
        use_bias=b is not None
    )

    projection = FinalProjection(config, driver=driver)
    projection.weight = W
    if b is not None:
        projection.bias = b
    # Rebuild the cached, driver-prepared buffers so they reflect the weights
    # passed in here rather than the randomly initialized ones from __init__.
    projection._setup_cache()

    return projection.forward(x)
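
if __name__ == "__main__":
    # Small smoke-test sketch with assumed toy shapes and no hardware driver:
    # the deprecated functional API should match a direct NumPy computation.
    rng = np.random.default_rng(0)
    hidden = rng.standard_normal((2, 8, 32)).astype(np.float32)
    W = rng.standard_normal((32, 100)).astype(np.float32)
    b = np.zeros(100, dtype=np.float32)

    logits = final_linear_projection(hidden, W, b)
    expected = hidden.reshape(-1, 32) @ W + b
    assert logits.shape == (2, 8, 100)
    assert np.allclose(logits.reshape(-1, 100), expected, atol=1e-4)
    print("final_linear_projection smoke test passed:", logits.shape)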