# INV/helium/final_projection.py
import warnings
from dataclasses import dataclass
from typing import Optional
import numpy as np
@dataclass
class ProjectionConfig:
"""Configuration for final projection layer"""
hidden_dim: int
vocab_size: int
use_bias: bool = True
use_fp16: bool = False
use_weight_tying: bool = False
dropout_rate: float = 0.0
initializer_range: float = 0.02
class FinalProjection:
"""
Optimized final projection layer for transformer models with support for:
- Weight tying with input embeddings
- Mixed precision (FP16/FP32)
- Memory-efficient computation
- Optimized matrix multiplication
- Bias fusion
"""
def __init__(
self,
config: ProjectionConfig,
embedding_weights: Optional[np.ndarray] = None,
driver = None
):
"""
Initialize the final projection layer.
Args:
config: Projection configuration
embedding_weights: Optional weights for weight tying with input embeddings
driver: Optional hardware driver for optimized computation
"""
self.config = config
self.driver = driver
        if config.use_weight_tying and embedding_weights is not None:
            # Tie weights with the input embedding matrix, assumed to have
            # shape (vocab_size, hidden_dim); transpose it for the projection
            self.weight = embedding_weights.T
        else:
            if config.use_weight_tying:
                warnings.warn(
                    "use_weight_tying is enabled but no embedding_weights were "
                    "provided; initializing a fresh projection matrix instead"
                )
            # Initialize new weights
            self.weight = self._initialize_weights()
if config.use_bias:
self.bias = np.zeros(config.vocab_size, dtype=self._get_dtype())
else:
self.bias = None
# Cache for optimizations
self._setup_cache()
def _get_dtype(self) -> np.dtype:
"""Get the appropriate dtype based on configuration"""
return np.float16 if self.config.use_fp16 else np.float32
def _initialize_weights(self) -> np.ndarray:
"""Initialize projection weights"""
return np.random.normal(
0.0,
self.config.initializer_range,
(self.config.hidden_dim, self.config.vocab_size)
).astype(self._get_dtype())
    def _setup_cache(self):
        """Setup computation cache for optimizations"""
        self._cached_shapes = {}
        # Default to None so these attributes always exist, even when a driver
        # is present but the bias is disabled or no preparation hooks are exposed
        self._prepared_weight = None
        self._prepared_bias = None
        if self.driver and hasattr(self.driver, 'prepare_projection'):
            self._prepared_weight = self.driver.prepare_projection(self.weight)
            if self.bias is not None:
                self._prepared_bias = self.driver.prepare_bias(self.bias)
def _apply_dropout(
self,
x: np.ndarray,
training: bool = False
) -> np.ndarray:
"""Apply dropout if configured"""
if training and self.config.dropout_rate > 0:
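            # Inverted dropout: the kept activations are scaled by
            # 1 / (1 - dropout_rate), so no rescaling is needed at inference time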
mask = np.random.binomial(
1,
1.0 - self.config.dropout_rate,
x.shape
).astype(self._get_dtype()) / (1.0 - self.config.dropout_rate)
return x * mask
return x
def _validate_input(self, x: np.ndarray):
"""Validate input tensor shape and type"""
if x.ndim != 3:
raise ValueError(
f"Expected 3D input tensor (batch, seq_len, hidden_dim), got shape {x.shape}"
)
if x.shape[-1] != self.config.hidden_dim:
raise ValueError(
f"Input hidden dimension {x.shape[-1]} doesn't match "
f"configured hidden_dim {self.config.hidden_dim}"
)
def _optimize_computation(
self,
x: np.ndarray,
batch_size: int,
seq_len: int
) -> np.ndarray:
"""Optimize computation based on input shape and hardware"""
shape_key = (batch_size, seq_len)
# Use cached computation plan if available
if shape_key in self._cached_shapes:
return self._cached_shapes[shape_key](x)
if self.driver and hasattr(self.driver, 'optimized_projection'):
# Use hardware-specific optimizations
compute_plan = self.driver.optimized_projection(
batch_size,
seq_len,
self._prepared_weight,
self._prepared_bias
)
self._cached_shapes[shape_key] = compute_plan
return compute_plan(x)
return None # Fall back to standard computation
def forward(
self,
x: np.ndarray,
training: bool = False
) -> np.ndarray:
"""
Forward pass of the final projection layer.
Args:
x: Input tensor of shape (batch_size, seq_len, hidden_dim)
training: Whether in training mode (enables dropout)
Returns:
logits: Output logits of shape (batch_size, seq_len, vocab_size)
"""
self._validate_input(x)
# Cast input to appropriate dtype
x = x.astype(self._get_dtype())
# Apply dropout during training
x = self._apply_dropout(x, training)
batch_size, seq_len, _ = x.shape
# Try optimized computation path
optimized_result = self._optimize_computation(x, batch_size, seq_len)
if optimized_result is not None:
return optimized_result
# Standard computation path
if self.driver and self._prepared_weight is not None:
# Use prepared weights if available
logits = self.driver.matmul(x, self._prepared_weight)
if self._prepared_bias is not None:
logits = self.driver.add_bias(logits, self._prepared_bias)
else:
# Fallback to NumPy computation
# Reshape for efficient matrix multiplication
x_2d = x.reshape(-1, self.config.hidden_dim)
logits = np.matmul(x_2d, self.weight)
if self.bias is not None:
logits += self.bias
# Reshape back to 3D
logits = logits.reshape(batch_size, seq_len, self.config.vocab_size)
return logits
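
# The `driver` argument is duck-typed: FinalProjection only calls
# prepare_projection, prepare_bias, optimized_projection, matmul and add_bias,
# and only when they exist. The class below is a minimal NumPy-backed sketch of
# that assumed protocol (an illustration, not a real hardware driver) that can
# be used to exercise the driver code paths.
class NumpyReferenceDriver:
    """Reference driver implementing the hooks FinalProjection looks for."""
    def prepare_projection(self, weight: np.ndarray) -> np.ndarray:
        # A real driver might upload the weight to device memory or repack it;
        # here we simply keep a contiguous copy
        return np.ascontiguousarray(weight)
    def prepare_bias(self, bias: np.ndarray) -> np.ndarray:
        return np.ascontiguousarray(bias)
    def optimized_projection(self, batch_size, seq_len, weight, bias):
        # Return a compute plan (a callable) specialized to one (batch, seq) shape
        def plan(x: np.ndarray) -> np.ndarray:
            logits = np.matmul(x.reshape(-1, weight.shape[0]), weight)
            if bias is not None:
                logits += bias
            return logits.reshape(batch_size, seq_len, weight.shape[1])
        return plan
    def matmul(self, x: np.ndarray, weight: np.ndarray) -> np.ndarray:
        return np.matmul(x, weight)
    def add_bias(self, logits: np.ndarray, bias: np.ndarray) -> np.ndarray:
        return logits + bias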
def final_linear_projection(
x: np.ndarray,
W: np.ndarray,
b: Optional[np.ndarray] = None,
driver = None
) -> np.ndarray:
"""
Legacy function for backward compatibility.
Args:
x: Input tensor (batch, seq_len, hidden_dim)
W: Weight matrix (hidden_dim, vocab_size)
b: Optional bias vector (vocab_size,)
driver: Optional hardware driver
Returns:
logits: Output logits (batch, seq_len, vocab_size)
"""
    warnings.warn(
        "final_linear_projection is deprecated, use FinalProjection class instead",
        DeprecationWarning,
        stacklevel=2
    )
config = ProjectionConfig(
hidden_dim=W.shape[0],
vocab_size=W.shape[1],
use_bias=b is not None
)
    projection = FinalProjection(config, driver=driver)
    projection.weight = W
    if b is not None:
        projection.bias = b
    # Re-run cache setup so any driver-prepared buffers refer to the supplied
    # weights and bias rather than the randomly initialized ones
    projection._setup_cache()
    return projection.forward(x)
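
# Minimal usage sketch (sizes are illustrative only): build a small projection,
# run a forward pass on the NumPy fallback path and on the reference driver
# path sketched above, and exercise the deprecated functional wrapper.
if __name__ == "__main__":
    cfg = ProjectionConfig(hidden_dim=64, vocab_size=1000)
    proj = FinalProjection(cfg)
    hidden_states = np.random.randn(2, 8, cfg.hidden_dim).astype(np.float32)
    logits = proj.forward(hidden_states)
    assert logits.shape == (2, 8, cfg.vocab_size)
    # Driver-backed path using the reference driver sketched above
    driver_proj = FinalProjection(cfg, driver=NumpyReferenceDriver())
    assert driver_proj.forward(hidden_states).shape == logits.shape
    # Deprecated functional wrapper kept for backward compatibility
    legacy_logits = final_linear_projection(hidden_states, proj.weight, proj.bias)
    assert legacy_logits.shape == logits.shape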