# helium/attention_utils.py
from typing import Dict, Optional, Tuple, Union
import numpy as np
from .broadcast import ModalityType


class AttentionState:
    """Per-layer state tracking for attention computations.

    ``stored_tensors`` maps logical parameter names (e.g. "gate_weight")
    to driver-side tensor identifiers.
    """

    def __init__(self, driver, name: str):
        self.driver = driver
        self.name = name
        self.stored_tensors: Dict[str, str] = {}
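

# Usage sketch (hypothetical tensor names, not part of the helium API): the
# gated fusion path in fuse_cross_modal_attention() looks up "gate_weight"
# in stored_tensors, so callers are expected to register it up front.
def _attention_state_example(driver) -> AttentionState:
    """Illustrative only: register the gate projection used by gated fusion."""
    state = AttentionState(driver, "cross_attn_0")
    state.stored_tensors["gate_weight"] = "cross_attn_0.gate_weight"
    return state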


def split_heads(
    x: Union[str, "HeliumTensor"],
    num_heads: int,
    driver,
    modality: Optional[ModalityType] = None
) -> Union[str, "HeliumTensor"]:
    """Split the hidden dim into attention heads.

    Reshapes (batch, seq, hidden) to (batch, num_heads, seq, head_dim) and
    optionally applies modality-specific scaling.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)
    batch_size, seq_len, hidden_dim = x.shape
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim {hidden_dim} is not divisible by num_heads {num_heads}"
        )
    head_dim = hidden_dim // num_heads
    # Reshape and transpose: (batch, seq, hidden) -> (batch, heads, seq, head_dim)
    x = driver.reshape(x, (batch_size, seq_len, num_heads, head_dim))
    x = driver.transpose(x, (0, 2, 1, 3))
    # Apply modality-specific scaling relative to a baseline head size
    # (64 for images, 32 for audio)
    if modality is not None:
        scale = 1.0
        if modality == ModalityType.IMAGE:
            scale = np.sqrt(head_dim / 64)
        elif modality == ModalityType.AUDIO:
            scale = np.sqrt(head_dim / 32)
        if scale != 1.0:
            x = driver.mul_scalar(x, scale)
    return x
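

# Shape-only NumPy reference for split_heads (a minimal sketch that assumes
# nothing about the driver): (batch, seq, hidden) -> (batch, heads, seq, head_dim).
def _split_heads_reference(x: np.ndarray, num_heads: int) -> np.ndarray:
    batch_size, seq_len, hidden_dim = x.shape
    head_dim = hidden_dim // num_heads
    x = x.reshape(batch_size, seq_len, num_heads, head_dim)
    return x.transpose(0, 2, 1, 3)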


def apply_rotary_embedding(
    x: Union[str, "HeliumTensor"],
    seq_len: int,
    head_dim: int,
    driver
) -> Union[str, "HeliumTensor"]:
    """Apply rotary positional embeddings: x * cos + rotate_half(x) * sin."""
    if isinstance(x, str):
        x = driver.get_tensor(x)
    half = head_dim // 2
    # Per-position rotation angles: theta_i = pos * 10000^(-2i / head_dim)
    pos = np.arange(seq_len)
    freqs = np.exp(-np.arange(0, head_dim, 2) * np.log(10000.0) / head_dim)
    angles = pos[:, None] * freqs[None, :]  # (seq_len, half)
    # Duplicate along the last dim so cos/sin line up with the two halves of x
    cos = np.concatenate([np.cos(angles), np.cos(angles)], axis=-1)  # (seq_len, head_dim)
    sin = np.concatenate([np.sin(angles), np.sin(angles)], axis=-1)
    # rotate_half(x) = concat(-x2, x1), expressed as a matmul with a constant
    # signed-permutation matrix, since the driver exposes no slicing primitive
    rot = np.zeros((head_dim, head_dim))
    rot[half:, :half] = -np.eye(half)
    rot[:half, half:] = np.eye(half)
    # Move constants to device (driver.mul is assumed to broadcast the
    # (seq_len, head_dim) factors over the leading batch/head dims)
    cos = driver.to_gpu(cos)
    sin = driver.to_gpu(sin)
    rot = driver.to_gpu(rot)
    # Apply the rotation
    x = driver.add(driver.mul(x, cos), driver.mul(driver.matmul(x, rot), sin))
    return x
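

# NumPy reference for the same RoPE computation (illustrative; mirrors the
# driver version above so the two can be checked against each other). Each
# (x1, x2) pair is rotated by its angle theta, which preserves vector norm.
def _rope_reference(x: np.ndarray) -> np.ndarray:
    """x has shape (..., seq_len, head_dim)."""
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    half = head_dim // 2
    pos = np.arange(seq_len)
    freqs = np.exp(-np.arange(0, head_dim, 2) * np.log(10000.0) / head_dim)
    angles = pos[:, None] * freqs[None, :]
    cos = np.concatenate([np.cos(angles), np.cos(angles)], axis=-1)
    sin = np.concatenate([np.sin(angles), np.sin(angles)], axis=-1)
    x1, x2 = x[..., :half], x[..., half:]
    rotated = np.concatenate([-x2, x1], axis=-1)
    return x * cos + rotated * sin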


def fuse_cross_modal_attention(
    q: Union[str, "HeliumTensor"],
    k: Union[str, "HeliumTensor"],
    v: Union[str, "HeliumTensor"],
    q_modality: ModalityType,
    kv_modality: ModalityType,
    fusion_type: str,
    driver,
    state: AttentionState
) -> Tuple[Union[str, "HeliumTensor"], Union[str, "HeliumTensor"], Union[str, "HeliumTensor"]]:
    """Fuse cross-modal attention patterns.

    Supported fusion types are "additive", "multiplicative", and "gated";
    all three collapse q and k onto a shared fused representation, while v
    passes through unchanged. q_modality and kv_modality are currently
    unused but kept for interface symmetry.
    """
    if fusion_type == "additive":
        # Simple additive fusion: q and k share the summed representation
        q = driver.add(q, k)
        k = q
    elif fusion_type == "multiplicative":
        # Element-wise multiplicative fusion
        q = driver.mul(q, k)
        k = q
    elif fusion_type == "gated":
        # Gated fusion: q' = g * q + (1 - g) * k, with the gate g computed
        # from a learned projection stored in the attention state
        gate_weight = state.stored_tensors.get("gate_weight")
        if gate_weight is None:
            raise KeyError(
                f"gated fusion requires 'gate_weight' in AttentionState {state.name!r}"
            )
        gate = driver.sigmoid(driver.matmul(q, gate_weight))
        q = driver.add(
            driver.mul(gate, q),
            driver.mul(driver.sub(1.0, gate), k)
        )
        k = q
    else:
        raise ValueError(f"unknown fusion_type: {fusion_type!r}")
    return q, k, v
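

# A minimal NumPy stand-in for the driver interface, assuming the method
# semantics implied above (elementwise add/mul with broadcasting, matmul on
# the trailing dims, no-op device transfer). Hypothetical and for testing
# only; the real helium driver is not defined in this module.
class _NumpyDriver:
    def to_gpu(self, a): return np.asarray(a)
    def reshape(self, a, shape): return a.reshape(shape)
    def transpose(self, a, axes): return a.transpose(axes)
    def matmul(self, a, b): return a @ b
    def add(self, a, b): return a + b
    def sub(self, a, b): return a - b
    def mul(self, a, b): return a * b
    def mul_scalar(self, a, s): return a * s
    def sigmoid(self, a): return 1.0 / (1.0 + np.exp(-a))


if __name__ == "__main__":
    driver = _NumpyDriver()
    x = np.random.randn(2, 8, 64)  # (batch, seq, hidden)
    heads = split_heads(x, num_heads=4, driver=driver)
    assert heads.shape == (2, 4, 8, 16)
    rotated = apply_rotary_embedding(heads, seq_len=8, head_dim=16, driver=driver)
    # Rotations preserve norms, so RoPE should not change vector lengths,
    # and the driver path should agree with the NumPy reference
    assert np.allclose(np.linalg.norm(rotated, axis=-1),
                       np.linalg.norm(heads, axis=-1))
    assert np.allclose(rotated, _rope_reference(heads))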