from typing import Optional
import warnings
from dataclasses import dataclass

import numpy as np


@dataclass
class ProjectionConfig:
    """Configuration for the final projection layer."""
    hidden_dim: int
    vocab_size: int
    use_bias: bool = True
    use_fp16: bool = False
    use_weight_tying: bool = False
    dropout_rate: float = 0.0
    initializer_range: float = 0.02


class FinalProjection:
    """
    Optimized final projection layer for transformer models with support for:
    - Weight tying with input embeddings
    - Mixed precision (FP16/FP32)
    - Memory-efficient computation
    - Optimized matrix multiplication
    - Bias fusion
    """

    def __init__(
        self,
        config: ProjectionConfig,
        embedding_weights: Optional[np.ndarray] = None,
        driver=None,
    ):
        """
        Initialize the final projection layer.

        Args:
            config: Projection configuration
            embedding_weights: Optional weights for weight tying with input embeddings
            driver: Optional hardware driver for optimized computation
        """
        self.config = config
        self.driver = driver

        if config.use_weight_tying and embedding_weights is not None:
            # Tie weights with the input embeddings: the embedding matrix is
            # (vocab_size, hidden_dim), so transpose it for the projection.
            self.weight = embedding_weights.T
        else:
            # Initialize new weights
            self.weight = self._initialize_weights()

        if config.use_bias:
            self.bias = np.zeros(config.vocab_size, dtype=self._get_dtype())
        else:
            self.bias = None

        # Cache for optimizations
        self._setup_cache()

    def _get_dtype(self) -> type:
        """Return the appropriate dtype based on the configuration."""
        return np.float16 if self.config.use_fp16 else np.float32

    def _initialize_weights(self) -> np.ndarray:
        """Initialize projection weights."""
        return np.random.normal(
            0.0,
            self.config.initializer_range,
            (self.config.hidden_dim, self.config.vocab_size),
        ).astype(self._get_dtype())

    def _setup_cache(self):
        """Set up the computation cache for optimizations."""
        self._cached_shapes = {}
        self._prepared_weight = None
        self._prepared_bias = None
        if self.driver and hasattr(self.driver, 'prepare_projection'):
            self._prepared_weight = self.driver.prepare_projection(self.weight)
            if self.bias is not None:
                self._prepared_bias = self.driver.prepare_bias(self.bias)

    def _apply_dropout(
        self,
        x: np.ndarray,
        training: bool = False,
    ) -> np.ndarray:
        """Apply (inverted) dropout if configured."""
        if training and self.config.dropout_rate > 0:
            keep_prob = 1.0 - self.config.dropout_rate
            mask = np.random.binomial(
                1, keep_prob, x.shape
            ).astype(self._get_dtype()) / keep_prob
            return x * mask
        return x

    def _validate_input(self, x: np.ndarray):
        """Validate input tensor shape and type."""
        if x.ndim != 3:
            raise ValueError(
                f"Expected 3D input tensor (batch, seq_len, hidden_dim), got shape {x.shape}"
            )
        if x.shape[-1] != self.config.hidden_dim:
            raise ValueError(
                f"Input hidden dimension {x.shape[-1]} doesn't match "
                f"configured hidden_dim {self.config.hidden_dim}"
            )

    def _optimize_computation(
        self,
        x: np.ndarray,
        batch_size: int,
        seq_len: int,
    ) -> Optional[np.ndarray]:
        """Optimize computation based on input shape and hardware."""
        shape_key = (batch_size, seq_len)

        # Use a cached computation plan if available
        if shape_key in self._cached_shapes:
            return self._cached_shapes[shape_key](x)

        if self.driver and hasattr(self.driver, 'optimized_projection'):
            # Use hardware-specific optimizations
            compute_plan = self.driver.optimized_projection(
                batch_size,
                seq_len,
                self._prepared_weight,
                self._prepared_bias,
            )
            self._cached_shapes[shape_key] = compute_plan
            return compute_plan(x)

        return None  # Fall back to standard computation

    def forward(
        self,
        x: np.ndarray,
        training: bool = False,
    ) -> np.ndarray:
        """
        Forward pass of the final projection layer.

        Args:
            x: Input tensor of shape (batch_size, seq_len, hidden_dim)
            training: Whether in training mode (enables dropout)

        Returns:
            logits: Output logits of shape (batch_size, seq_len, vocab_size)
        """
        self._validate_input(x)

        # Cast input to the appropriate dtype
        x = x.astype(self._get_dtype())

        # Apply dropout during training
        x = self._apply_dropout(x, training)

        batch_size, seq_len, _ = x.shape

        # Try the optimized computation path
        optimized_result = self._optimize_computation(x, batch_size, seq_len)
        if optimized_result is not None:
            return optimized_result

        # Standard computation path
        if self.driver and self._prepared_weight is not None:
            # Use prepared weights if available
            logits = self.driver.matmul(x, self._prepared_weight)
            if self._prepared_bias is not None:
                logits = self.driver.add_bias(logits, self._prepared_bias)
        else:
            # Fall back to NumPy computation.
            # Flatten to 2D for efficient matrix multiplication.
            x_2d = x.reshape(-1, self.config.hidden_dim)
            logits = np.matmul(x_2d, self.weight)
            if self.bias is not None:
                logits += self.bias
            # Reshape back to 3D
            logits = logits.reshape(batch_size, seq_len, self.config.vocab_size)

        return logits


def final_linear_projection(
    x: np.ndarray,
    W: np.ndarray,
    b: Optional[np.ndarray] = None,
    driver=None,
) -> np.ndarray:
    """
    Legacy function for backward compatibility.

    Args:
        x: Input tensor (batch, seq_len, hidden_dim)
        W: Weight matrix (hidden_dim, vocab_size)
        b: Optional bias vector (vocab_size,)
        driver: Optional hardware driver

    Returns:
        logits: Output logits (batch, seq_len, vocab_size)
    """
    warnings.warn(
        "final_linear_projection is deprecated, use FinalProjection class instead",
        DeprecationWarning,
    )

    config = ProjectionConfig(
        hidden_dim=W.shape[0],
        vocab_size=W.shape[1],
        use_bias=b is not None,
    )

    projection = FinalProjection(config, driver=driver)
    projection.weight = W
    if b is not None:
        projection.bias = b
    # Re-run cache setup so any driver-prepared tensors reflect W and b
    # rather than the randomly initialized weights.
    projection._setup_cache()

    return projection.forward(x)
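

# --- Usage sketch (illustrative only, not part of the module's API) ---
# A minimal example of running FinalProjection on the NumPy-only fallback
# path (no hardware driver). The config values, tensor shapes, and random
# hidden states below are assumptions made purely for demonstration.
if __name__ == "__main__":
    config = ProjectionConfig(hidden_dim=64, vocab_size=1000, use_bias=True)
    projection = FinalProjection(config)

    # Fake hidden states: (batch_size=2, seq_len=8, hidden_dim=64)
    hidden_states = np.random.randn(2, 8, 64).astype(np.float32)

    logits = projection.forward(hidden_states, training=False)
    print(logits.shape)  # Expected: (2, 8, 1000)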