import numpy as np

from typing import Optional, Tuple, Dict, Union, List, Any
from dataclasses import dataclass
from enum import Enum

from .softmax import softmax
from .broadcast import ModalityType, TensorMetadata
from .tensor import HeliumTensor
# scaled_dot_product_attention is assumed to live alongside AttentionState;
# adjust this import if it is defined elsewhere in the package.
from .attention_utils import AttentionState, scaled_dot_product_attention
# NOTE: the module-level definitions further below shadow these three imports.
from .utils import split_heads, apply_rotary_embedding, fuse_cross_modal_attention


class AttentionType(Enum):
    """Types of attention patterns"""
    SELF = "self"
    CROSS = "cross"
    LOCAL = "local"
    SPARSE = "sparse"
    GLOBAL = "global"


@dataclass
class AttentionConfig:
    """Configuration for multi-modal attention"""
    attention_type: AttentionType
    num_heads: int
    hidden_dim: int
    cross_modality_fusion: str = "additive"
    use_rotary: bool = False
    # Optional pattern parameters referenced by _get_attention_mask; the
    # defaults below are conservative fall-backs.
    window_size: Optional[int] = None
    sparsity_factor: Optional[int] = None
    modality_specific: bool = False
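

# Illustrative sketch only (not part of the library API): build a config for
# 8-head, 512-dim local attention with rotary embeddings. The specific values
# are arbitrary example choices, not defaults mandated by the library.
def _example_attention_config() -> AttentionConfig:
    return AttentionConfig(
        attention_type=AttentionType.LOCAL,
        num_heads=8,
        hidden_dim=512,
        cross_modality_fusion="additive",
        use_rotary=True,
        window_size=64,
    )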


class HeliumMultiHeadAttention:
    """
    Multi-modal attention implementation with support for:
    - Cross-modal attention
    - Modality-specific patterns
    - Local/sparse attention
    - Rotary embeddings
    - Fusion mechanisms
    """
    def __init__(self, config: AttentionConfig, device_id: Optional[str] = None):
        self.config = config
        self.device_id = device_id
        self.head_dim = config.hidden_dim // config.num_heads

        self.projections = self._create_projections()

        self.output_projection = self._create_projection(scale=1.0)

        # Cached masks, keyed by (q_modality, k_modality, q_length, k_length).
        self.pattern_cache: Dict[Tuple, Optional[Union[np.ndarray, HeliumTensor]]] = {}

    def _create_projections(self) -> Dict[str, Dict[str, Any]]:
        """Create projection matrices for Q,K,V for each modality"""
        projections = {}

        for modality in ModalityType:
            scale = 1.0
            if modality == ModalityType.IMAGE:
                scale = np.sqrt(self.head_dim / 64)
            elif modality == ModalityType.AUDIO:
                scale = np.sqrt(self.head_dim / 32)

            q_proj = self._create_projection(scale=scale)
            k_proj = self._create_projection(scale=scale)
            v_proj = self._create_projection(scale=scale)

            projections[modality] = {
                'query': q_proj,
                'key': k_proj,
                'value': v_proj
            }

        return projections

    def _create_projection(self, scale: float = 1.0) -> Dict[str, Union[np.ndarray, HeliumTensor]]:
        """Create a single projection matrix"""
        std = scale * np.sqrt(2.0 / (2.0 * self.config.hidden_dim))
        weight = np.random.normal(0, std, (self.config.hidden_dim, self.config.hidden_dim))
        bias = np.zeros(self.config.hidden_dim)

        if self.device_id:
            weight = HeliumTensor(weight, device=self.device_id)
            bias = HeliumTensor(bias, device=self.device_id)

        return {'weight': weight, 'bias': bias}

    def forward(
        self,
        hidden_states: Union[str, HeliumTensor],
        attention_mask: Optional[Union[str, HeliumTensor]] = None,
        modality: Optional[ModalityType] = None,
        cross_states: Optional[Union[str, HeliumTensor]] = None,
        cross_modality: Optional[ModalityType] = None,
        metadata: Optional[TensorMetadata] = None
    ) -> Tuple[Union[str, HeliumTensor], Dict[str, Any]]:
        """
        Multi-modal attention forward pass
        """
        # The input tensor's device acts as the compute driver for this pass.
        driver = hidden_states.device if hasattr(hidden_states, 'device') else None
        state = AttentionState(driver, "mm_attn")

        mod = modality or ModalityType.TEXT
        projections = self.projections[mod]

        # Queries are projected from the input states; keys/values come from
        # the cross-modal states when provided, otherwise from the input.
        kv_states = hidden_states if cross_states is None else cross_states
        q = driver.matmul(hidden_states, projections['query']['weight'])
        k = driver.matmul(kv_states, projections['key']['weight'])
        v = driver.matmul(kv_states, projections['value']['weight'])

        q = split_heads(q, self.config.num_heads, driver, state, modality)
        k = split_heads(k, self.config.num_heads, driver, state, cross_modality or modality)
        v = split_heads(v, self.config.num_heads, driver, state, cross_modality or modality)

        if self.config.use_rotary:
            seq_len = hidden_states.shape[1]
            q = apply_rotary_embedding(q, seq_len, self.head_dim, driver, state)
            k = apply_rotary_embedding(k, seq_len, self.head_dim, driver, state)

        if cross_states is not None and cross_modality != modality:
            q, k, v = fuse_cross_modal_attention(
                q, k, v,
                modality,
                cross_modality,
                self.config.cross_modality_fusion,
                driver,
                state
            )

        if attention_mask is None and self.config.attention_type != AttentionType.GLOBAL:
            attention_mask = self._get_attention_mask(
                modality or ModalityType.TEXT,
                cross_modality or modality or ModalityType.TEXT,
                q.shape[2],
                k.shape[2]
            )

        scale = np.sqrt(self.head_dim)
        if modality == ModalityType.IMAGE:
            scale *= 2.0

        attn_output, attn_weights = scaled_dot_product_attention(
            q, k, v,
            mask=attention_mask,
            scale=scale,
            driver=driver
        )

        # Merge heads: move the head axis next to head_dim before collapsing
        # them into hidden_dim.
        attn_output = driver.transpose(attn_output, (0, 2, 1, 3))
        attn_output = driver.reshape(attn_output, (
            attn_output.shape[0],
            attn_output.shape[1],
            self.config.hidden_dim
        ))

        output = driver.matmul(attn_output, self.output_projection['weight'])

        if metadata:
            metadata.modality = modality
            metadata.operation = "attention"
            metadata.shape = output.shape

        return output, {'attention_weights': attn_weights}

    def _get_attention_mask(
        self,
        q_modality: ModalityType,
        k_modality: ModalityType,
        q_length: int,
        k_length: int
    ) -> Optional[Union[np.ndarray, HeliumTensor]]:
        """Get or create attention mask for given modalities"""
        key = (q_modality, k_modality, q_length, k_length)
        if key in self.pattern_cache:
            return self.pattern_cache[key]

        mask = None
        if self.config.attention_type == AttentionType.LOCAL:
            # Mask out positions outside the local window.
            window = self.config.window_size or q_length // 8
            indices = np.arange(q_length)
            mask = np.abs(indices[:, None] - indices) > window

        elif self.config.attention_type == AttentionType.SPARSE:
            # Mask out positions that do not fall on the stride pattern.
            stride = self.config.sparsity_factor or 8
            indices = np.arange(q_length)
            mask = (indices[:, None] - indices) % stride != 0

        if mask is not None and self.config.modality_specific:
            if q_modality == ModalityType.IMAGE:
                # Treat the sequence as a square grid and refine by 2-D distance.
                h = w = int(np.sqrt(q_length))
                if h * w == q_length:
                    rows, cols = np.meshgrid(np.arange(h), np.arange(w), indexing='ij')
                    rows = rows.reshape(-1)
                    cols = cols.reshape(-1)
                    dist = (rows[:, None] - rows[None, :]) ** 2 + (cols[:, None] - cols[None, :]) ** 2
                    mask = np.logical_and(mask, dist > 4)

            elif q_modality == ModalityType.AUDIO:
                # Refine by distance in frequency.
                freqs = np.fft.fftfreq(q_length)
                mask = np.logical_and(mask, np.abs(freqs[:, None] - freqs) > 0.25)

        if mask is not None and self.device_id:
            mask = HeliumTensor(mask, device=self.device_id)

        self.pattern_cache[key] = mask
        return mask


def create_attention_mask(
    q_modality: ModalityType,
    k_modality: ModalityType,
    q_length: int,
    k_length: int,
    attention_type: AttentionType,
    window_size: Optional[int] = None
) -> np.ndarray:
    """Build a dense 0/1 attention mask (1 = attend, 0 = blocked)."""
    mask = np.ones((q_length, k_length), dtype=np.float32)

    if attention_type == AttentionType.LOCAL and window_size:
        # Keep only keys within the local window around each query position.
        for i in range(q_length):
            start = max(0, i - window_size)
            end = min(k_length, i + window_size + 1)
            mask[i, :start] = 0
            mask[i, end:] = 0

    elif attention_type == AttentionType.SPARSE:
        # Block everything first, then re-enable strided key positions.
        stride = max(1, k_length // 8)
        mask[:, :] = 0
        mask[:, ::stride] = 1

    if q_modality != k_modality:
        if q_modality == ModalityType.TEXT and k_modality == ModalityType.IMAGE:
            # Text queries may attend to all image keys.
            pass
        elif q_modality == ModalityType.IMAGE and k_modality == ModalityType.TEXT:
            # Image queries attend to every other text key.
            mask[:, ::2] = 1
            mask[:, 1::2] = 0

    return mask
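

# Illustrative sketch only (not part of the library API): exercise
# create_attention_mask for a small local-attention case. The sizes here
# (8 positions, window of 2) are arbitrary example choices.
def _example_local_mask() -> np.ndarray:
    mask = create_attention_mask(
        ModalityType.TEXT, ModalityType.TEXT,
        q_length=8, k_length=8,
        attention_type=AttentionType.LOCAL,
        window_size=2,
    )
    # Each query row i keeps ones only for keys j with |i - j| <= 2,
    # so row 0 starts with three allowed keys followed by zeros.
    assert mask[0, :4].tolist() == [1.0, 1.0, 1.0, 0.0]
    return mask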


def split_heads(
    x_name: str,
    num_heads: int,
    driver,
    state: AttentionState,
    modality: Optional[ModalityType] = None
) -> str:
    """
    Split the last dimension into (num_heads, head_dim) with modality-specific processing
    All operations in driver memory
    Returns: name of resulting tensor in driver
    """
    x = driver.get_tensor(x_name)
    batch, seq_len, hidden_dim = x.shape
    head_dim = hidden_dim // num_heads

    if modality:
        scale = 1.0
        if modality == ModalityType.IMAGE:
            scale = np.sqrt(head_dim / 64)
        elif modality == ModalityType.TEXT:
            scale = 1.0

        x = x * scale

    reshaped_name = state.get_temp_tensor(
        x.reshape(batch, seq_len, num_heads, head_dim),
        "reshaped"
    )

    if hasattr(driver, 'set_tensor_metadata') and modality:
        driver.set_tensor_metadata(
            reshaped_name,
            TensorMetadata(
                modality=modality,
                shape=x.shape,
                dtype=x.dtype
            )
        )

    transposed_name = state.get_temp_tensor(
        driver.transpose(reshaped_name, (0, 2, 1, 3)),
        "transposed"
    )

    state.free_temp_tensor(reshaped_name)
    return transposed_name


def apply_rotary_embedding(
    x_name: str,
    seq_len: int,
    head_dim: int,
    driver,
    state: AttentionState,
    base: int = 10000
) -> str:
    """Apply rotary positional embeddings"""
    x = driver.get_tensor(x_name)
    batch_size, num_heads = x.shape[:2]

    position = np.arange(seq_len)
    dim = np.arange(head_dim // 2) * 2

    freq = 1.0 / (base ** (dim / head_dim))
    freq = np.einsum('i,j->ij', position, freq)

    cos = np.cos(freq)[None, None, :, :]
    sin = np.sin(freq)[None, None, :, :]

    x_reshaped = x.reshape(batch_size, num_heads, seq_len, head_dim // 2, 2)

    x_rot = np.concatenate([
        x_reshaped[..., 0] * cos - x_reshaped[..., 1] * sin,
        x_reshaped[..., 0] * sin + x_reshaped[..., 1] * cos
    ], axis=-1)

    rotated_name = state.get_temp_tensor(x_rot, "rotary")
    return rotated_name
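

# Illustrative sketch only (pure NumPy, no driver or state objects): the
# rotary update above is a 2-D rotation of each (even, odd) feature pair, so
# it preserves the norm of every pair. The sizes below are arbitrary examples.
def _example_rotary_preserves_pair_norm() -> bool:
    head_dim, seq_len, base = 8, 6, 10000
    x = np.random.randn(1, 1, seq_len, head_dim)
    position = np.arange(seq_len)
    dim = np.arange(head_dim // 2) * 2
    freq = np.einsum('i,j->ij', position, 1.0 / (base ** (dim / head_dim)))
    cos = np.cos(freq)[None, None, :, :]
    sin = np.sin(freq)[None, None, :, :]
    pairs = x.reshape(1, 1, seq_len, head_dim // 2, 2)
    rotated = np.stack([
        pairs[..., 0] * cos - pairs[..., 1] * sin,
        pairs[..., 0] * sin + pairs[..., 1] * cos
    ], axis=-1)
    # Rotations leave the length of each 2-D pair unchanged.
    return bool(np.allclose(np.linalg.norm(pairs, axis=-1),
                            np.linalg.norm(rotated, axis=-1)))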


def fuse_cross_modal_attention(
    q_name: str,
    k_name: str,
    v_name: str,
    q_modality: ModalityType,
    k_modality: ModalityType,
    fusion_type: str,
    driver,
    state: AttentionState
) -> Tuple[str, str, str]:
    """
    Fuse attention across different modalities

    Args:
        q_name: Query tensor name
        k_name: Key tensor name
        v_name: Value tensor name
        q_modality: Query modality
        k_modality: Key modality
        fusion_type: Type of fusion (additive, multiplicative, gated)
    """
    q = driver.get_tensor(q_name)
    k = driver.get_tensor(k_name)
    v = driver.get_tensor(v_name)

    if fusion_type == "additive":
        # Placeholder zero biases; learned cross-modal biases would slot in here.
        bias_shape = (1, q.shape[1], 1, q.shape[-1])
        q_bias = np.zeros(bias_shape)
        k_bias = np.zeros(bias_shape)

        q_fused_name = state.get_temp_tensor(q + q_bias, "q_fused")
        k_fused_name = state.get_temp_tensor(k + k_bias, "k_fused")
        v_fused_name = v_name

    elif fusion_type == "multiplicative":
        # Rescale text queries/keys by sqrt(head_dim) before mixing modalities.
        q_scale = np.sqrt(q.shape[-1]) if q_modality == ModalityType.TEXT else 1.0
        k_scale = np.sqrt(k.shape[-1]) if k_modality == ModalityType.TEXT else 1.0

        q_fused_name = state.get_temp_tensor(q * q_scale, "q_fused")
        k_fused_name = state.get_temp_tensor(k * k_scale, "k_fused")
        v_fused_name = v_name

    elif fusion_type == "gated":
        # Placeholder unit gates; learned gates would replace the ones below.
        gate_shape = (1, q.shape[1], 1, 1)
        q_gate = np.ones(gate_shape)
        k_gate = np.ones(gate_shape)

        q_fused_name = state.get_temp_tensor(q * q_gate, "q_fused")
        k_fused_name = state.get_temp_tensor(k * k_gate, "k_fused")
        v_fused_name = v_name

    else:
        raise ValueError(f"Unknown fusion_type: {fusion_type}")

    return q_fused_name, k_fused_name, v_fused_name


def combine_heads(
    x_name: str,
    driver,
    state: AttentionState,
    modality: Optional[ModalityType] = None
) -> str:
    """
    Combine heads with modality-specific processing
    All operations in driver memory
    Returns: name of resulting tensor in driver
    """
    x = driver.get_tensor(x_name)
    batch, num_heads, seq_len, head_dim = x.shape

    transposed_name = state.get_temp_tensor(
        driver.transpose(x_name, (0, 2, 1, 3)),
        "transposed_back"
    )
    reshaped_name = state.get_temp_tensor(
        driver.reshape(transposed_name, (batch, seq_len, num_heads * head_dim)),
        "reshaped_back"
    )

    state.free_temp_tensor(transposed_name)
    return reshaped_name
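

# Illustrative sketch only (pure NumPy, no driver or state objects): the
# split_heads/combine_heads pair relies on reshaping to (heads, head_dim) and
# transposing, then undoing both, recovering the original layout exactly.
# The sizes below are arbitrary example choices.
def _example_head_split_roundtrip() -> bool:
    batch, seq_len, hidden_dim, num_heads = 2, 5, 8, 4
    head_dim = hidden_dim // num_heads
    x = np.random.randn(batch, seq_len, hidden_dim)
    # Split: (B, S, H*D) -> (B, H, S, D), as in split_heads above.
    heads = x.reshape(batch, seq_len, num_heads, head_dim).transpose(0, 2, 1, 3)
    # Combine: transpose back before collapsing heads, as in combine_heads above.
    back = heads.transpose(0, 2, 1, 3).reshape(batch, seq_len, hidden_dim)
    return bool(np.allclose(x, back))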


# The definitions below mirror HeliumMultiHeadAttention but route all tensor
# work through an explicit compute driver; they expect to be bound to an
# attention object that stores that driver.
def __init__(
    self,
    config: AttentionConfig,
    device_id: Optional[str] = None,
    driver = None
):
    self.config = config
    self.device_id = device_id
    self.driver = driver
    self.head_dim = config.hidden_dim // config.num_heads

    self.projections = self._create_projections()

    # Output projection used by forward() below.
    self.output_projection = self._create_projection(scale=1.0)

    self.pattern_cache: Dict[Tuple, Optional[np.ndarray]] = {}


def _create_projections(self) -> Dict[str, Dict[str, Any]]:
    """Create projection matrices for Q,K,V"""
    projections = {}

    for modality in ModalityType:
        scale = 1.0
        if modality == ModalityType.IMAGE:
            scale = np.sqrt(self.head_dim / 64)
        elif modality == ModalityType.AUDIO:
            scale = np.sqrt(self.head_dim / 32)

        q_proj = self._create_projection(scale=scale)
        k_proj = self._create_projection(scale=scale)
        v_proj = self._create_projection(scale=scale)

        projections[modality] = {
            'query': q_proj,
            'key': k_proj,
            'value': v_proj
        }

    return projections


def _create_projection(self, scale: float = 1.0) -> Dict[str, np.ndarray]:
    """Create a single projection matrix"""
    std = scale * np.sqrt(2.0 / (2.0 * self.config.hidden_dim))
    weight = np.random.normal(0, std, (self.config.hidden_dim, self.config.hidden_dim))
    bias = np.zeros(self.config.hidden_dim)

    if hasattr(self.driver, 'to_gpu'):
        weight = self.driver.to_gpu(weight)
        bias = self.driver.to_gpu(bias)

    return {'weight': weight, 'bias': bias}


def forward(
    self,
    hidden_states: Union[str, "HeliumTensor"],
    attention_mask: Optional[Union[str, "HeliumTensor"]] = None,
    modality: Optional[ModalityType] = None,
    cross_states: Optional[Union[str, "HeliumTensor"]] = None,
    cross_modality: Optional[ModalityType] = None,
    metadata: Optional[TensorMetadata] = None
) -> Tuple[Union[str, "HeliumTensor"], Dict[str, Any]]:
    """
    Multi-modal attention forward pass
    """
    state = AttentionState(self.driver, "mm_attn")

    # Resolve a tensor name into a driver tensor when needed.
    if isinstance(hidden_states, str):
        query = self.driver.get_tensor(hidden_states)
    else:
        query = hidden_states

    projections = self.projections[modality or ModalityType.TEXT]
    key = query if cross_states is None else cross_states
    value = key

    # Project Q/K/V with their own modality-specific weights.
    q = self.driver.matmul(query, projections['query']['weight'])
    k = self.driver.matmul(key, projections['key']['weight'])
    v = self.driver.matmul(value, projections['value']['weight'])

    q = split_heads(q, self.config.num_heads, self.driver, state, modality)
    k = split_heads(k, self.config.num_heads, self.driver, state, cross_modality or modality)
    v = split_heads(v, self.config.num_heads, self.driver, state, cross_modality or modality)

    if self.config.use_rotary:
        q = apply_rotary_embedding(q, query.shape[1], self.head_dim, self.driver, state)
        k = apply_rotary_embedding(k, key.shape[1], self.head_dim, self.driver, state)

    if cross_states is not None and cross_modality != modality:
        q, k, v = fuse_cross_modal_attention(
            q, k, v,
            modality,
            cross_modality,
            self.config.cross_modality_fusion,
            self.driver,
            state
        )

    if attention_mask is None and self.config.attention_type != AttentionType.GLOBAL:
        attention_mask = self._get_attention_mask(
            modality or ModalityType.TEXT,
            cross_modality or modality or ModalityType.TEXT,
            query.shape[1],
            key.shape[1]
        )

    scale = np.sqrt(self.head_dim)
    if modality == ModalityType.IMAGE:
        scale *= 2.0

    attn_output, attn_weights = scaled_dot_product_attention(
        q, k, v,
        mask=attention_mask,
        scale=scale,
        driver=self.driver
    )

    # Merge heads: move the head axis next to head_dim before collapsing
    # them into hidden_dim.
    attn_output = self.driver.transpose(attn_output, (0, 2, 1, 3))
    attn_output = self.driver.reshape(attn_output, (
        attn_output.shape[0],
        attn_output.shape[1],
        self.config.hidden_dim
    ))

    output = self.driver.matmul(attn_output, self.output_projection['weight'])

    if metadata:
        metadata.modality = modality
        metadata.operation = "attention"
        metadata.shape = output.shape

    return output, {'attention_weights': attn_weights}


def multihead_attention(
    x_name: str,
    Wq_name: str,
    Wk_name: str,
    Wv_name: str,
    Wo_name: str,
    num_heads: int,
    mask_name: Optional[str] = None,
    driver = None,
    chip_id: int = 0,
    sm_id: int = 0,
    scheduler = None
) -> Tuple[str, str]:
    """
    All tensors referenced by their names in driver storage
    Returns: (output_name, attention_weights_name) in driver
    """
    if driver is None:
        raise ValueError("Driver is required for GPU-backed attention")

    state = AttentionState(driver, f"mha_{chip_id}_{sm_id}")

    Q_name = state.get_temp_tensor(
        driver.matmul(x_name, Wq_name, chip_id=chip_id, sm_id=sm_id),
        "Q"
    )
    K_name = state.get_temp_tensor(
        driver.matmul(x_name, Wk_name, chip_id=chip_id, sm_id=sm_id),
        "K"
    )
    V_name = state.get_temp_tensor(
        driver.matmul(x_name, Wv_name, chip_id=chip_id, sm_id=sm_id),
        "V"
    )

    Q_heads_name = split_heads(Q_name, num_heads, driver, state)
    K_heads_name = split_heads(K_name, num_heads, driver, state)
    V_heads_name = split_heads(V_name, num_heads, driver, state)

    state.free_temp_tensor(Q_name)
    state.free_temp_tensor(K_name)
    state.free_temp_tensor(V_name)

    attn_output_name, attn_weights_name = scaled_dot_product_attention(
        Q_heads_name, K_heads_name, V_heads_name,
        mask_name=mask_name,
        driver=driver,
        chip_id=chip_id,
        sm_id=sm_id,
        scheduler=scheduler
    )

    state.free_temp_tensor(Q_heads_name)
    state.free_temp_tensor(K_heads_name)
    state.free_temp_tensor(V_heads_name)

    combined_name = combine_heads(attn_output_name, driver, state)
    state.free_temp_tensor(attn_output_name)

    output_name = state.get_temp_tensor(
        driver.matmul(combined_name, Wo_name, chip_id=chip_id, sm_id=sm_id),
        "output"
    )
    state.free_temp_tensor(combined_name)

    return output_name, attn_weights_name
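

# Illustrative sketch only: the driver-backed scaled_dot_product_attention
# used throughout this module is imported above and its exact signature varies
# between the array-based and name-based call sites. This pure-NumPy reference
# only documents the math on plain arrays (for reading and testing), assuming
# a truthy mask entry marks a position that must NOT be attended to, matching
# _get_attention_mask above.
def _reference_scaled_dot_product_attention(
    q: np.ndarray,
    k: np.ndarray,
    v: np.ndarray,
    mask: Optional[np.ndarray] = None,
    scale: Optional[float] = None,
) -> Tuple[np.ndarray, np.ndarray]:
    """Shapes: q, k, v are (batch, heads, seq, head_dim); returns (output, weights)."""
    if scale is None:
        scale = np.sqrt(q.shape[-1])
    # Similarity scores between every query and key position.
    scores = np.matmul(q, np.swapaxes(k, -1, -2)) / scale
    if mask is not None:
        scores = np.where(np.asarray(mask, dtype=bool), -1e9, scores)
    # Numerically stable softmax over the key axis.
    scores = scores - scores.max(axis=-1, keepdims=True)
    weights = np.exp(scores)
    weights = weights / weights.sum(axis=-1, keepdims=True)
    return np.matmul(weights, v), weights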