import numpy as np
from typing import Optional, Tuple, Dict, Union, List, Any
from dataclasses import dataclass
from enum import Enum
from .softmax import softmax
from .broadcast import ModalityType, TensorMetadata
from .tensor import HeliumTensor
# scaled_dot_product_attention is assumed to live alongside AttentionState in
# .attention_utils; adjust the import path if it is defined elsewhere in the package.
from .attention_utils import AttentionState, scaled_dot_product_attention
# split_heads, apply_rotary_embedding and fuse_cross_modal_attention are defined
# later in this module, so they are not imported from .utils.
class AttentionType(Enum):
"""Types of attention patterns"""
SELF = "self"
CROSS = "cross"
LOCAL = "local"
SPARSE = "sparse"
GLOBAL = "global"
@dataclass
class AttentionConfig:
"""Configuration for multi-modal attention"""
attention_type: AttentionType
num_heads: int
hidden_dim: int
cross_modality_fusion: str = "additive"
use_rotary: bool = False
# Optional fields consumed by _get_attention_mask for local/sparse patterns
window_size: Optional[int] = None
sparsity_factor: Optional[int] = None
modality_specific: bool = False
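# Example (illustrative sketch): a local self-attention configuration with a
# sliding window; the field values below are placeholders, not project defaults.
# config = AttentionConfig(
#     attention_type=AttentionType.LOCAL,
#     num_heads=8,
#     hidden_dim=512,
#     use_rotary=True,
#     window_size=64,
# )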
class HeliumMultiHeadAttention:
"""
Multi-modal attention implementation with support for:
- Cross-modal attention
- Modality-specific patterns
- Local/sparse attention
- Rotary embeddings
- Fusion mechanisms
"""
def __init__(self, config: AttentionConfig, device_id: Optional[str] = None):
self.config = config
self.device_id = device_id
self.head_dim = config.hidden_dim // config.num_heads
# Initialize modality-specific projections
self.projections = self._create_projections()
# Initialize output projection
self.output_projection = self._create_projection(scale=1.0)
# Cache for attention patterns, keyed by (q_modality, k_modality, q_len, k_len)
self.pattern_cache: Dict[Tuple, Any] = {}
def _create_projections(self) -> Dict[str, Dict[str, Any]]:
"""Create projection matrices for Q,K,V for each modality"""
projections = {}
for modality in ModalityType:
# Get modality-specific scaling
scale = 1.0
if modality == ModalityType.IMAGE:
scale = np.sqrt(self.head_dim / 64)
elif modality == ModalityType.AUDIO:
scale = np.sqrt(self.head_dim / 32)
# Create projections
q_proj = self._create_projection(scale=scale)
k_proj = self._create_projection(scale=scale)
v_proj = self._create_projection(scale=scale)
projections[modality] = {
'query': q_proj,
'key': k_proj,
'value': v_proj
}
return projections
def _create_projection(self, scale: float = 1.0) -> Dict[str, Union[np.ndarray, HeliumTensor]]:
"""Create a single projection matrix"""
std = scale * np.sqrt(2.0 / (2.0 * self.config.hidden_dim))
weight = np.random.normal(0, std, (self.config.hidden_dim, self.config.hidden_dim))
bias = np.zeros(self.config.hidden_dim)
if self.device_id:
# Move to device if specified
weight = HeliumTensor(weight, device=self.device_id)
bias = HeliumTensor(bias, device=self.device_id)
return {'weight': weight, 'bias': bias}
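# The initialization above is Xavier-style: std = scale * sqrt(1 / hidden_dim).
# For example, with hidden_dim = 512 and scale = 1.0, std is roughly 0.044.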
def forward(
self,
hidden_states: Union[str, HeliumTensor],
attention_mask: Optional[Union[str, HeliumTensor]] = None,
modality: Optional[ModalityType] = None,
cross_states: Optional[Union[str, HeliumTensor]] = None,
cross_modality: Optional[ModalityType] = None,
metadata: Optional[TensorMetadata] = None
) -> Tuple[Union[str, HeliumTensor], Dict[str, Any]]:
"""
Multi-modal attention forward pass
"""
# Initialize computation state; the tensor's device acts as the compute driver
driver = getattr(hidden_states, 'device', None)
state = AttentionState(driver, "mm_attn")
# Get projection matrices
mod = modality or ModalityType.TEXT
projections = self.projections[mod]
# Project inputs; keys and values come from cross_states during cross-attention
kv_states = hidden_states if cross_states is None else cross_states
q = driver.matmul(hidden_states, projections['query']['weight'])
k = driver.matmul(kv_states, projections['key']['weight'])
v = driver.matmul(kv_states, projections['value']['weight'])
# Split heads with modality awareness
q = split_heads(q, self.config.num_heads, driver, state, modality)
k = split_heads(k, self.config.num_heads, driver, state, cross_modality or modality)
v = split_heads(v, self.config.num_heads, driver, state, cross_modality or modality)
# Apply rotary embeddings if configured
if self.config.use_rotary:
seq_len = hidden_states.shape[1]
q = apply_rotary_embedding(q, seq_len, self.head_dim, driver, state)
k = apply_rotary_embedding(k, kv_states.shape[1], self.head_dim, driver, state)
# Handle cross-modal attention
if cross_states is not None and cross_modality != modality:
q, k, v = fuse_cross_modal_attention(
q, k, v,
modality,
cross_modality,
self.config.cross_modality_fusion,
driver,
state
)
# Get attention mask
if attention_mask is None and self.config.attention_type != AttentionType.GLOBAL:
attention_mask = self._get_attention_mask(
modality or ModalityType.TEXT,
cross_modality or modality or ModalityType.TEXT,
hidden_states.shape[1],
kv_states.shape[1]
)
# Compute attention
scale = np.sqrt(self.head_dim)
if modality == ModalityType.IMAGE:
scale *= 2.0
attn_output, attn_weights = scaled_dot_product_attention(
q, k, v,
mask=attention_mask,
scale=scale,
driver=driver
)
# Combine heads back to (batch, seq_len, hidden_dim)
attn_output = combine_heads(attn_output, driver, state, modality)
# Project output
output = driver.matmul(attn_output, self.output_projection['weight'])
# Add metadata
if metadata:
metadata.modality = modality
metadata.operation = "attention"
metadata.shape = output.shape
return output, {'attention_weights': attn_weights}
def _get_attention_mask(
self,
q_modality: ModalityType,
k_modality: ModalityType,
q_length: int,
k_length: int
) -> Optional[Union[str, HeliumTensor]]:
"""Get or create attention mask for given modalities"""
key = (q_modality, k_modality, q_length, k_length)
if key in self.pattern_cache:
return self.pattern_cache[key]
# Create attention mask based on attention type
mask = None
if self.config.attention_type == AttentionType.LOCAL:
# Local attention with sliding window
window = self.config.window_size or q_length // 8
q_idx = np.arange(q_length)[:, None]
k_idx = np.arange(k_length)[None, :]
mask = np.abs(q_idx - k_idx) > window
elif self.config.attention_type == AttentionType.SPARSE:
# Sparse attention with strided pattern
stride = self.config.sparsity_factor or 8
q_idx = np.arange(q_length)[:, None]
k_idx = np.arange(k_length)[None, :]
mask = (q_idx - k_idx) % stride != 0
# Add modality-specific patterns
if mask is not None and self.config.modality_specific:
if q_modality == ModalityType.IMAGE and q_length == k_length:
# Add local 2D structure for images (assumes a square patch grid)
h = w = int(np.sqrt(q_length))
if h * w == q_length: # Perfect square
rows = np.repeat(np.arange(h), w)
cols = np.tile(np.arange(w), h)
dist = (rows[:, None] - rows[None, :]) ** 2 + (cols[:, None] - cols[None, :]) ** 2
mask = np.logical_and(mask, dist > 4)
elif q_modality == ModalityType.AUDIO:
# Add frequency-based patterns for audio
q_freqs = np.fft.fftfreq(q_length)
k_freqs = np.fft.fftfreq(k_length)
mask = np.logical_and(mask, np.abs(q_freqs[:, None] - k_freqs[None, :]) > 0.25)
if mask is not None and self.device_id:
mask = HeliumTensor(mask, device=self.device_id)
self.pattern_cache[key] = mask
return mask
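# Example usage (illustrative sketch): names such as `token_states` and
# `image_states` are placeholders, not part of the package API.
# attention = HeliumMultiHeadAttention(config)
# output, extras = attention.forward(
#     token_states,                      # tensor of shape (batch, seq, hidden_dim)
#     modality=ModalityType.TEXT,
#     cross_states=image_states,         # optional tensor from another modality
#     cross_modality=ModalityType.IMAGE,
# )
# attention_weights = extras['attention_weights']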
def create_attention_mask(
q_modality: ModalityType,
k_modality: ModalityType,
q_length: int,
k_length: int,
attention_type: AttentionType,
window_size: Optional[int] = None
) -> np.ndarray:
"""
Build a dense attention mask for a (q_modality, k_modality) pair.
Returns a float mask where 1.0 marks positions that may be attended to
and 0.0 marks blocked positions.
"""
mask = np.ones((q_length, k_length), dtype=np.float32)
if attention_type == AttentionType.LOCAL and window_size:
# Create local attention pattern
for i in range(q_length):
start = max(0, i - window_size)
end = min(k_length, i + window_size + 1)
mask[i, :start] = 0
mask[i, end:] = 0
elif attention_type == AttentionType.SPARSE:
# Create sparse attention pattern: block everything, then open strided positions
stride = max(1, k_length // 8) # Example: attend to every 8th position
mask[:, :] = 0
mask[:, ::stride] = 1
# Modality-specific masking
if q_modality != k_modality:
if q_modality == ModalityType.TEXT and k_modality == ModalityType.IMAGE:
# Text can attend to full image
pass
elif q_modality == ModalityType.IMAGE and k_modality == ModalityType.TEXT:
# Image attends to text sparsely
mask[:, ::2] = 1 # Example: attend to every other text token
mask[:, 1::2] = 0
return mask
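# Example (illustrative): a local text-to-text mask with a window of 1 over
# 4 query and 4 key positions.
# create_attention_mask(ModalityType.TEXT, ModalityType.TEXT, 4, 4,
#                       AttentionType.LOCAL, window_size=1)
# -> [[1, 1, 0, 0],
#     [1, 1, 1, 0],
#     [0, 1, 1, 1],
#     [0, 0, 1, 1]]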
def split_heads(
x_name: str,
num_heads: int,
driver,
state: AttentionState,
modality: Optional[ModalityType] = None
) -> str:
"""
Split the last dimension into (num_heads, head_dim) with modality-specific processing
All operations in driver memory
Returns: name of resulting tensor in driver
"""
x = driver.get_tensor(x_name)
batch, seq_len, hidden_dim = x.shape
head_dim = hidden_dim // num_heads
# Apply modality-specific head scaling
if modality:
scale = 1.0
if modality == ModalityType.IMAGE:
# Scale image heads differently
scale = np.sqrt(head_dim / 64) # Example scaling
elif modality == ModalityType.TEXT:
scale = 1.0
x = x * scale
# Reshape and transpose in driver memory
reshaped_name = state.get_temp_tensor(
x.reshape(batch, seq_len, num_heads, head_dim),
"reshaped"
)
# Add modality info to metadata if supported
if hasattr(driver, 'set_tensor_metadata') and modality:
driver.set_tensor_metadata(
reshaped_name,
TensorMetadata(
modality=modality,
shape=x.shape,
dtype=x.dtype
)
)
transposed_name = state.get_temp_tensor(
driver.transpose(reshaped_name, (0, 2, 1, 3)),
"transposed"
)
state.free_temp_tensor(reshaped_name)
return transposed_name
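# Shape sketch: split_heads maps a tensor of shape (batch, seq_len, hidden_dim)
# to (batch, num_heads, seq_len, head_dim) with head_dim = hidden_dim // num_heads,
# e.g. (2, 128, 512) with 8 heads becomes (2, 8, 128, 64).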
def apply_rotary_embedding(
x_name: str,
seq_len: int,
head_dim: int,
driver,
state: AttentionState,
base: int = 10000
) -> str:
"""Apply rotary positional embeddings"""
x = driver.get_tensor(x_name)
batch_size, num_heads = x.shape[:2]
# Create position indices
position = np.arange(seq_len)
# Create dimension indices
dim = np.arange(head_dim // 2) * 2
# Compute frequencies
freq = 1.0 / (base ** (dim / head_dim))
freq = np.einsum('i,j->ij', position, freq)
# Compute rotations
cos = np.cos(freq)[None, None, :, :]
sin = np.sin(freq)[None, None, :, :]
# Reshape x for rotation
x_reshaped = x.reshape(batch_size, num_heads, seq_len, head_dim // 2, 2)
# Apply rotation
x_rot = np.concatenate([
x_reshaped[..., 0] * cos - x_reshaped[..., 1] * sin,
x_reshaped[..., 0] * sin + x_reshaped[..., 1] * cos
], axis=-1)
rotated_name = state.get_temp_tensor(x_rot, "rotary")
return rotated_name
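# Rotary embedding sketch: each channel pair (x0, x1) at position p is rotated by
# angle p * theta_j, where theta_j = 1 / base ** (2j / head_dim). For example, with
# base = 10000 and head_dim = 64, theta_0 = 1.0 and theta_31 is roughly 1.33e-4.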
def fuse_cross_modal_attention(
q_name: str,
k_name: str,
v_name: str,
q_modality: ModalityType,
k_modality: ModalityType,
fusion_type: str,
driver,
state: AttentionState
) -> Tuple[str, str, str]:
"""
Fuse attention across different modalities
Args:
q_name: Query tensor name
k_name: Key tensor name
v_name: Value tensor name
q_modality: Query modality
k_modality: Key modality
fusion_type: Type of fusion (additive, multiplicative, gated)
"""
q = driver.get_tensor(q_name)
k = driver.get_tensor(k_name)
v = driver.get_tensor(v_name)
if fusion_type == "additive":
# Add modality-specific learnable bias
bias_shape = (1, q.shape[1], 1, q.shape[-1])
q_bias = np.zeros(bias_shape)
k_bias = np.zeros(bias_shape)
q_fused_name = state.get_temp_tensor(q + q_bias, "q_fused")
k_fused_name = state.get_temp_tensor(k + k_bias, "k_fused")
v_fused_name = v_name
elif fusion_type == "multiplicative":
# Apply modality-specific scaling
q_scale = np.sqrt(q.shape[-1]) if q_modality == ModalityType.TEXT else 1.0
k_scale = np.sqrt(k.shape[-1]) if k_modality == ModalityType.TEXT else 1.0
q_fused_name = state.get_temp_tensor(q * q_scale, "q_fused")
k_fused_name = state.get_temp_tensor(k * k_scale, "k_fused")
v_fused_name = v_name
elif fusion_type == "gated":
# Learn modality-specific gating
gate_shape = (1, q.shape[1], 1, 1)
q_gate = np.ones(gate_shape) # Initialize to 1
k_gate = np.ones(gate_shape)
q_fused_name = state.get_temp_tensor(q * q_gate, "q_fused")
k_fused_name = state.get_temp_tensor(k * k_gate, "k_fused")
v_fused_name = v_name
return q_fused_name, k_fused_name, v_fused_name
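# Example (illustrative): fusing text queries with image keys/values using the
# additive strategy; the tensor names below are placeholders in driver storage.
# q_name, k_name, v_name = fuse_cross_modal_attention(
#     "q_heads", "k_heads", "v_heads",
#     ModalityType.TEXT, ModalityType.IMAGE,
#     fusion_type="additive", driver=driver, state=state,
# )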
def combine_heads(
x_name: str,
driver,
state: AttentionState,
modality: Optional[ModalityType] = None
) -> str:
"""
Combine heads with modality-specific processing
All operations in driver memory
Returns: name of resulting tensor in driver
"""
x = driver.get_tensor(x_name)
batch, num_heads, seq_len, head_dim = x.shape
# Transpose and reshape in driver memory
transposed_name = state.get_temp_tensor(
driver.transpose(x_name, (0, 2, 1, 3)),
"transposed_back"
)
reshaped_name = state.get_temp_tensor(
driver.reshape(transposed_name, (batch, seq_len, num_heads * head_dim)),
"reshaped_back"
)
state.free_temp_tensor(transposed_name)
return reshaped_name
# Driver-backed variant of the constructor: projections live in driver memory
# rather than being wrapped in HeliumTensor objects.
def __init__(
self,
config: AttentionConfig,
device_id: Optional[str] = None,
driver = None
):
self.config = config
self.device_id = device_id
self.driver = driver
self.head_dim = config.hidden_dim // config.num_heads
# Initialize modality-specific projections
self.projections = self._create_projections()
# Initialize output projection
self.output_projection = self._create_projection(scale=1.0)
# Cache for attention patterns, keyed by (q_modality, k_modality, q_len, k_len)
self.pattern_cache: Dict[Tuple, Any] = {}
def _create_projections(self) -> Dict[str, Dict[str, Any]]:
"""Create projection matrices for Q,K,V"""
projections = {}
for modality in ModalityType:
# Get modality-specific scaling
scale = 1.0
if modality == ModalityType.IMAGE:
scale = np.sqrt(self.head_dim / 64)
elif modality == ModalityType.AUDIO:
scale = np.sqrt(self.head_dim / 32)
# Create projections
q_proj = self._create_projection(scale=scale)
k_proj = self._create_projection(scale=scale)
v_proj = self._create_projection(scale=scale)
projections[modality] = {
'query': q_proj,
'key': k_proj,
'value': v_proj
}
return projections
def _create_projection(self, scale: float = 1.0) -> Dict[str, np.ndarray]:
"""Create a single projection matrix"""
std = scale * np.sqrt(2.0 / (2.0 * self.config.hidden_dim))
weight = np.random.normal(0, std, (self.config.hidden_dim, self.config.hidden_dim))
bias = np.zeros(self.config.hidden_dim)
if hasattr(self.driver, 'to_gpu'):
weight = self.driver.to_gpu(weight)
bias = self.driver.to_gpu(bias)
return {'weight': weight, 'bias': bias}
def forward(
self,
hidden_states: Union[str, "HeliumTensor"],
attention_mask: Optional[Union[str, "HeliumTensor"]] = None,
modality: Optional[ModalityType] = None,
cross_states: Optional[Union[str, "HeliumTensor"]] = None,
cross_modality: Optional[ModalityType] = None,
metadata: Optional[TensorMetadata] = None
) -> Tuple[Union[str, "HeliumTensor"], Dict[str, Any]]:
"""
Multi-modal attention forward pass
"""
# Initialize computation state
state = AttentionState(self.driver, f"mm_attn")
# Get input tensors from names/references
if isinstance(hidden_states, str):
query = self.driver.get_tensor(hidden_states)
else:
query = hidden_states
# Select keys/values source (cross_states during cross-attention)
if isinstance(cross_states, str):
cross_states = self.driver.get_tensor(cross_states)
key = query if cross_states is None else cross_states
value = key
# Project with the modality-specific query/key/value weights
proj = self.projections[modality or ModalityType.TEXT]
q = self.driver.matmul(query, proj['query']['weight'])
k = self.driver.matmul(key, proj['key']['weight'])
v = self.driver.matmul(value, proj['value']['weight'])
# Split heads with modality awareness
q = split_heads(q, self.config.num_heads, self.driver, state, modality)
k = split_heads(k, self.config.num_heads, self.driver, state, cross_modality or modality)
v = split_heads(v, self.config.num_heads, self.driver, state, cross_modality or modality)
# Apply rotary embeddings if configured
if self.config.use_rotary:
q = apply_rotary_embedding(q, query.shape[1], self.head_dim, self.driver, state)
k = apply_rotary_embedding(k, key.shape[1], self.head_dim, self.driver, state)
# Handle cross-modal attention
if cross_states is not None and cross_modality != modality:
q, k, v = fuse_cross_modal_attention(
q, k, v,
modality,
cross_modality,
self.config.cross_modality_fusion,
self.driver,
state
)
# Get attention mask
if attention_mask is None and self.config.attention_type != AttentionType.GLOBAL:
attention_mask = self._get_attention_mask(
modality or ModalityType.TEXT,
cross_modality or modality or ModalityType.TEXT,
query.shape[1],
key.shape[1]
)
# Compute attention with scaling
scale = np.sqrt(self.head_dim)
if modality == ModalityType.IMAGE:
scale *= 2.0 # Stronger scaling for image attention
attn_output, attn_weights = scaled_dot_product_attention(
q, k, v,
mask=attention_mask,
scale=scale,
driver=self.driver
)
# Combine heads back to (batch, seq_len, hidden_dim)
attn_output = combine_heads(attn_output, self.driver, state, modality)
# Project back
output = self.driver.matmul(attn_output, self.output_projection['weight'])
# Add metadata
if metadata:
metadata.modality = modality
metadata.operation = "attention"
metadata.shape = output.shape
return output, {'attention_weights': attn_weights}
def multihead_attention(
x_name: str,
Wq_name: str,
Wk_name: str,
Wv_name: str,
Wo_name: str,
num_heads: int,
mask_name: Optional[str] = None,
driver = None,
chip_id: int = 0,
sm_id: int = 0,
scheduler = None
) -> Tuple[str, str]:
"""
All tensors referenced by their names in driver storage
Returns: (output_name, attention_weights_name) in driver
"""
if driver is None:
raise ValueError("Driver is required for GPU-backed attention")
state = AttentionState(driver, f"mha_{chip_id}_{sm_id}")
# Compute Q, K, V projections in driver memory
Q_name = state.get_temp_tensor(
driver.matmul(x_name, Wq_name, chip_id=chip_id, sm_id=sm_id),
"Q"
)
K_name = state.get_temp_tensor(
driver.matmul(x_name, Wk_name, chip_id=chip_id, sm_id=sm_id),
"K"
)
V_name = state.get_temp_tensor(
driver.matmul(x_name, Wv_name, chip_id=chip_id, sm_id=sm_id),
"V"
)
# Split heads
Q_heads_name = split_heads(Q_name, num_heads, driver, state)
K_heads_name = split_heads(K_name, num_heads, driver, state)
V_heads_name = split_heads(V_name, num_heads, driver, state)
# Free original projections
state.free_temp_tensor(Q_name)
state.free_temp_tensor(K_name)
state.free_temp_tensor(V_name)
# Compute attention
attn_output_name, attn_weights_name = scaled_dot_product_attention(
Q_heads_name, K_heads_name, V_heads_name,
mask_name=mask_name,
driver=driver,
chip_id=chip_id,
sm_id=sm_id,
scheduler=scheduler
)
# Free split heads
state.free_temp_tensor(Q_heads_name)
state.free_temp_tensor(K_heads_name)
state.free_temp_tensor(V_heads_name)
# Combine heads
combined_name = combine_heads(attn_output_name, driver, state)
state.free_temp_tensor(attn_output_name)
# Final output projection
output_name = state.get_temp_tensor(
driver.matmul(combined_name, Wo_name, chip_id=chip_id, sm_id=sm_id),
"output"
)
state.free_temp_tensor(combined_name)
return output_name, attn_weights_name
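# Example usage (illustrative sketch): the tensor names and the `driver` object
# below are placeholders for whatever backend has registered those tensors.
# out_name, weights_name = multihead_attention(
#     "hidden_states",            # (batch, seq_len, hidden_dim) in driver storage
#     "Wq", "Wk", "Wv", "Wo",     # projection weight names
#     num_heads=8,
#     mask_name="causal_mask",    # optional
#     driver=driver,
# )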