| from typing import Dict, List, Optional, Tuple, Union
|
| import numpy as np
|
|
|
| from .broadcast import ModalityType
|
|
|
class AttentionState:
    """Per-attention-module bookkeeping.

    Keeps a handle to the backing driver and a registry mapping logical
    tensor names (e.g. "gate_weight") to driver-side tensor identifiers.
    """

    def __init__(self, driver, name: str):
        # Human-readable identifier for this attention instance.
        self.name = name
        # Backend that owns the tensors referenced below.
        self.driver = driver
        # Logical name -> driver tensor id; populated by callers.
        self.stored_tensors: Dict[str, str] = {}
|
|
|
def split_heads(
    x: Union[str, "HeliumTensor"],
    num_heads: int,
    driver,
    modality: Optional["ModalityType"] = None
) -> Union[str, "HeliumTensor"]:
    """Split the hidden dimension of *x* into multiple attention heads.

    Args:
        x: Tensor of shape (batch, seq_len, hidden_dim), or the name of a
            tensor registered with *driver* (resolved via ``get_tensor``).
        num_heads: Number of attention heads; must be positive and evenly
            divide ``hidden_dim``.
        driver: Backend providing ``reshape``/``transpose``/``mul_scalar``.
        modality: Optional modality selecting a per-modality scale factor.

    Returns:
        Tensor of shape (batch, num_heads, seq_len, head_dim).

    Raises:
        ValueError: If ``num_heads`` is non-positive or does not divide
            ``hidden_dim`` (previously this silently truncated via ``//``
            and produced a misaligned reshape).
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)

    batch_size, seq_len, hidden_dim = x.shape
    if num_heads <= 0 or hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim ({hidden_dim}) must be divisible by num_heads ({num_heads})"
        )
    head_dim = hidden_dim // num_heads

    # (batch, seq, hidden) -> (batch, heads, seq, head_dim)
    x = driver.reshape(x, (batch_size, seq_len, num_heads, head_dim))
    x = driver.transpose(x, (0, 2, 1, 3))

    # Explicit None check: `if modality:` would silently skip an enum
    # member with a falsy value.
    if modality is not None:
        # Per-modality rescaling relative to a reference head size
        # (64 for IMAGE, 32 for AUDIO) — presumably variance compensation
        # for attention logits; TODO confirm intent with model authors.
        scale = 1.0
        if modality == ModalityType.IMAGE:
            scale = np.sqrt(head_dim / 64)
        elif modality == ModalityType.AUDIO:
            scale = np.sqrt(head_dim / 32)

        if scale != 1.0:
            x = driver.mul_scalar(x, scale)

    return x
|
|
|
def apply_rotary_embedding(
    x: Union[str, "HeliumTensor"],
    seq_len: int,
    head_dim: int,
    driver
) -> Union[str, "HeliumTensor"]:
    """Apply rotary positional embeddings.

    Args:
        x: Tensor (or driver tensor name) to embed positions into.
        seq_len: Sequence length used to build the position grid.
        head_dim: Per-head dimension; frequencies span head_dim // 2 bands.
        driver: Backend providing ``to_gpu``/``matmul``/``add``.

    NOTE(review): this formulation differs from canonical RoPE, which
    rotates consecutive element pairs elementwise
    (x * cos + rotate_half(x) * sin). Here cos/sin are dense
    (seq_len, head_dim // 2) matrices fed to ``matmul``, which only
    conforms if the last dimension of ``x`` equals ``seq_len`` — likely a
    shape bug unless ``driver.matmul`` has non-standard broadcasting.
    Confirm driver semantics before relying on this.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)

    # Positions 0..seq_len-1.
    pos = np.arange(seq_len)

    # Standard RoPE inverse-frequency schedule: 10000^(-2i / head_dim)
    # for i over even indices — matches the Transformer sinusoid bands.
    freqs = np.exp(-np.arange(0, head_dim, 2) * np.log(10000) / head_dim)

    # Outer product: (seq_len, head_dim // 2) angle grid.
    angles = pos[:, None] * freqs[None, :]

    # reshape is a no-op here (already (seq_len, head_dim // 2)).
    cos = np.cos(angles).reshape(seq_len, -1)
    sin = np.sin(angles).reshape(seq_len, -1)

    # Move the tables onto the device alongside x.
    cos = driver.to_gpu(cos)
    sin = driver.to_gpu(sin)

    # NOTE(review): cos*x - sin*x via matmul rather than the usual paired
    # rotation — see docstring caveat about shapes.
    x_rot = driver.matmul(x, cos) - driver.matmul(x, sin)
    # NOTE(review): adding the rotated component back onto x (residual
    # style) is also non-standard for RoPE; verify against the caller.
    x = driver.add(x, x_rot)

    return x
|
|
|
def fuse_cross_modal_attention(
    q: Union[str, "HeliumTensor"],
    k: Union[str, "HeliumTensor"],
    v: Union[str, "HeliumTensor"],
    q_modality: "ModalityType",
    kv_modality: "ModalityType",
    fusion_type: str,
    driver,
    state: "AttentionState"
) -> Tuple[Union[str, "HeliumTensor"], Union[str, "HeliumTensor"], Union[str, "HeliumTensor"]]:
    """Fuse cross-modal attention patterns.

    Combines q and k according to *fusion_type*; after fusion, k aliases
    the fused q and v is always passed through untouched. An unrecognized
    *fusion_type* is a deliberate no-op pass-through of (q, k, v).

    Args:
        q, k, v: Query/key/value tensors (or driver tensor names).
        q_modality, kv_modality: Modalities of the streams. Currently
            unused — kept for interface stability; presumably reserved
            for modality-specific fusion. TODO confirm.
        fusion_type: One of "additive", "multiplicative", "gated".
        driver: Backend providing add/mul/sub/sigmoid/matmul.
        state: Attention state holding "gate_weight" for gated fusion.

    Returns:
        The (q, k, v) triple after fusion.

    Raises:
        KeyError: If ``fusion_type == "gated"`` but no "gate_weight"
            tensor is stored in *state* (previously a None weight was
            passed straight into ``driver.matmul``, failing cryptically).
    """
    if fusion_type == "additive":
        # q' = q + k; k aliases the fused result.
        q = driver.add(q, k)
        k = q
    elif fusion_type == "multiplicative":
        # q' = q * k (elementwise); k aliases the fused result.
        q = driver.mul(q, k)
        k = q
    elif fusion_type == "gated":
        # Fail fast with a clear message instead of matmul(q, None).
        gate_weight = state.stored_tensors.get("gate_weight")
        if gate_weight is None:
            raise KeyError(
                "gated fusion requires 'gate_weight' in state.stored_tensors"
            )
        # gate = sigmoid(q @ W); q' = gate * q + (1 - gate) * k.
        gate = driver.sigmoid(driver.matmul(q, gate_weight))
        q = driver.add(
            driver.mul(gate, q),
            driver.mul(driver.sub(1.0, gate), k)
        )
        k = q

    return q, k, v
|
|
|