|
|
"""
|
|
|
Hardware-accelerated multi-modal transformer decoder implementation for Helium virtual GPU
|
|
|
"""
|
|
|
from typing import Optional, Union, Dict, Any, TYPE_CHECKING, List, Tuple
|
|
|
from dataclasses import dataclass
|
|
|
import numpy as np
|
|
|
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, DType, Device, Layout
|
|
|
from virtual_gpu_driver.src.stream import Stream as ComputeStream
|
|
|
from virtual_gpu_driver.src.stream import StreamManager as KernelSchedule
|
|
|
from .main import get_device, get_default_device
|
|
|
from .layer_norm import HeliumLayerNorm
|
|
|
from .gelu import HeliumGELU
|
|
|
from .multihead_attention import HeliumMultiHeadAttention
|
|
|
from .core.db_manager import HeliumDBManager
|
|
|
from .broadcast import ModalityType, TensorMetadata
|
|
|
|
|
|
@dataclass
class DecoderConfig:
    """Configuration for the multi-modal decoder.

    Holds the transformer hyper-parameters plus the per-modality output
    settings. Call :meth:`validate` after construction to verify that every
    modality listed in ``output_modalities`` has its required field set.
    """

    output_modalities: List[ModalityType]
    hidden_dim: int
    num_layers: int
    num_heads: int
    intermediate_size: int
    max_seq_len: Dict[ModalityType, int]
    vocab_size: Optional[int] = None          # required when TEXT is requested
    image_size: Optional[Tuple[int, int]] = None  # (H, W), required for IMAGE
    audio_params: Optional[Dict[str, Any]] = None  # required for AUDIO
    use_cache: bool = True
    dtype: str = "float16"

    def validate(self):
        """Raise ``ValueError`` if a requested modality lacks its config field."""
        # Each known modality maps to (the field it needs, the error to raise).
        requirements = {
            ModalityType.TEXT: (self.vocab_size, "vocab_size required for text generation"),
            ModalityType.IMAGE: (self.image_size, "image_size required for image generation"),
            ModalityType.AUDIO: (self.audio_params, "audio_params required for audio generation"),
        }
        for modality in self.output_modalities:
            entry = requirements.get(modality)
            if entry is None:
                continue  # modality with no extra requirements
            value, message = entry
            if not value:
                raise ValueError(message)
|
|
|
|
|
|
if TYPE_CHECKING:
|
|
|
from .main import HeliumTensor
|
|
|
|
|
|
class ModalityProjection:
    """Projects decoder hidden states to a modality-specific output space.

    TEXT  -> vocabulary logits (hidden_dim -> vocab_size)
    IMAGE -> flattened RGB pixels (hidden_dim -> H*W*3), reshaped in forward()
    AUDIO -> raw samples (hidden_dim -> num_samples), optionally tanh-normalized
    """

    def __init__(
        self,
        config: DecoderConfig,
        modality: ModalityType,
        driver=None
    ):
        """
        Args:
            config: Decoder configuration; must carry the field required by
                ``modality`` (see DecoderConfig.validate).
            modality: Target output modality for this projection.
            driver: Virtual-GPU driver used for tensor allocation and kernels.

        Raises:
            ValueError: If ``modality`` has no projection defined. (The
                original code left ``self.proj`` unset in that case, which
                deferred the failure to forward(); fail fast instead.)
        """
        self.config = config
        self.modality = modality
        self.driver = driver

        if modality == ModalityType.TEXT:
            out_features = config.vocab_size
        elif modality == ModalityType.IMAGE:
            h, w = config.image_size
            out_features = h * w * 3  # flattened RGB image
        elif modality == ModalityType.AUDIO:
            out_features = config.audio_params["num_samples"]
        else:
            raise ValueError(f"No projection defined for modality {modality}")

        self.proj = self._create_linear(config.hidden_dim, out_features)

    def _create_linear(self, in_features: int, out_features: int) -> Dict[str, Any]:
        """Allocate weight/bias tensors for a linear projection on the vGPU.

        Honors ``config.dtype`` (the original hardcoded FLOAT16 here, which
        diverged from HeliumDecoderBlock._create_linear's configurable dtype).
        """
        dtype = getattr(DType, self.config.dtype.upper())
        weight_desc = TensorDescriptor(
            shape=(out_features, in_features),
            dtype=dtype,
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        bias_desc = TensorDescriptor(
            shape=(out_features,),
            dtype=dtype,
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        return {
            'weight': self.driver.allocate_tensor(weight_desc),
            'bias': self.driver.allocate_tensor(bias_desc)
        }

    def forward(
        self,
        hidden_states: Union[str, "HeliumTensor"]
    ) -> Union[str, "HeliumTensor"]:
        """Project hidden states to the modality-specific output space.

        Args:
            hidden_states: Decoder output; assumed (B, S, hidden_dim) —
                TODO confirm against HeliumDecoderBlock.forward.

        Returns:
            TEXT: logits as produced by the linear projection.
            IMAGE: tensor reshaped to (-1, H, W, 3).
            AUDIO: raw samples, tanh-squashed when audio_params["normalize"]
                is true (default).
        """
        # Linear projection: x @ W + b on the virtual GPU.
        out = self.driver.matmul(hidden_states, self.proj['weight'])
        out = self.driver.add(out, self.proj['bias'])

        if self.modality == ModalityType.IMAGE:
            # Unflatten pixels into an NHWC image.
            h, w = self.config.image_size
            out = self.driver.reshape(out, (-1, h, w, 3))
        elif self.modality == ModalityType.AUDIO:
            # tanh keeps samples in [-1, 1]; opt out via normalize=False.
            if self.config.audio_params.get("normalize", True):
                out = self.driver.tanh(out)

        return out
|
|
|
|
|
|
class HeliumDecoderBlock:
    """
    Hardware-accelerated multi-modal transformer decoder block

    Implements:
    1. Self-attention with causal mask
    2. Cross-attention with encoder outputs
    3. Feed-forward network
    4. Multi-modal output projections

    All operations run directly on virtual GPU with modality awareness
    """

    def __init__(
        self,
        config: DecoderConfig,
        device_id: Optional[str] = None
    ):
        """
        Args:
            config: Decoder hyper-parameters and per-modality output settings.
            device_id: Optional virtual-GPU device id; defaults to the
                driver's default device.

        Bug fix: the original __init__ contained a duplicated second copy of
        the module-construction code that referenced undefined bare locals
        (``hidden_size``, ``num_heads``, ``intermediate_size``, ``dtype``),
        so construction always raised NameError. The duplicate is removed;
        the surviving code uses the ``self.*`` attributes as intended.
        """
        self.driver = get_device(device_id) if device_id else get_default_device()
        self.device_id = device_id
        self.stream = ComputeStream(self.driver)

        # Shared DB manager singleton (persistence/bookkeeping).
        self.db = HeliumDBManager.get_instance()

        self.config = config

        # Cache frequently-used hyper-parameters.
        self.hidden_size = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        self.intermediate_size = config.intermediate_size
        self.dtype = config.dtype

        # Attention sublayers: causal self-attention + encoder-decoder
        # cross-attention, both on the same device and dtype.
        self.self_attention = HeliumMultiHeadAttention(
            hidden_size=self.hidden_size,
            num_heads=self.num_heads,
            device_id=device_id,
            dtype=self.dtype
        )
        self.cross_attention = HeliumMultiHeadAttention(
            hidden_size=self.hidden_size,
            num_heads=self.num_heads,
            device_id=device_id,
            dtype=self.dtype
        )

        # Pre-norm layer norms: ln1 before self-attn, ln2 before cross-attn,
        # ln3 before the feed-forward network.
        self.ln1 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype)
        self.ln2 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype)
        self.ln3 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype)

        # Feed-forward network: expand -> GELU -> contract.
        self.ff1 = self._create_linear(self.hidden_size, self.intermediate_size)
        self.ff2 = self._create_linear(self.intermediate_size, self.hidden_size)
        self.gelu = HeliumGELU(device_id=device_id)

        # One output head per requested modality.
        self.output_projections = {
            modality: ModalityProjection(config, modality, self.driver)
            for modality in config.output_modalities
        }

        self.schedule = KernelSchedule(self.driver)

        # Scratch-tensor bookkeeping for _get_temp_tensor/_free_temp_tensor.
        self._temp_tensors = {}
        self._counter = 0

    def _create_linear(self, in_features: int, out_features: int) -> Dict[str, Any]:
        """Create a linear layer's weight tensors.

        Allocates a (out_features, in_features) weight and an (out_features,)
        bias on the virtual GPU using the configured dtype.
        """
        weight_desc = TensorDescriptor(
            shape=(out_features, in_features),
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )

        bias_desc = TensorDescriptor(
            shape=(out_features,),
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )

        return {
            'weight': self.driver.allocate_tensor(weight_desc),
            'bias': self.driver.allocate_tensor(bias_desc)
        }

    def _get_temp_tensor(self, shape: tuple) -> str:
        """Allocate a temporary tensor and return its string id.

        The tensor is tracked in ``self._temp_tensors`` and must be released
        via _free_temp_tensor (or by __del__ as a last resort).
        """
        tensor_id = f"decoder_temp_{self._counter}"
        self._counter += 1

        desc = TensorDescriptor(
            shape=shape,
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )

        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(desc)
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str):
        """Free a temporary tensor; unknown ids are ignored."""
        if tensor_id in self._temp_tensors:
            self.driver.free_tensor(self._temp_tensors[tensor_id])
            del self._temp_tensors[tensor_id]

    def __del__(self):
        """Clean up any temporary tensors still alive.

        Uses getattr defensively: __del__ runs even when __init__ failed
        part-way, in which case _temp_tensors may not exist yet.
        """
        for tensor_id in list(getattr(self, "_temp_tensors", {})):
            self._free_temp_tensor(tensor_id)

    def forward(
        self,
        hidden_states: Union[str, "HeliumTensor"],
        target_modality: ModalityType,
        encoder_hidden_states: Optional[Union[str, "HeliumTensor"]] = None,
        attention_mask: Optional[Union[str, "HeliumTensor"]] = None,
        encoder_attention_mask: Optional[Union[str, "HeliumTensor"]] = None,
        metadata: Optional[TensorMetadata] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Forward pass of decoder block

        Args:
            hidden_states: Input tensor (B, S, H)
            target_modality: Modality whose output projection is applied;
                must be one of config.output_modalities.
            encoder_hidden_states: Optional encoder output (B, S_enc, H);
                when given, the cross-attention sublayer runs.
            attention_mask: Optional attention mask for self-attention
            encoder_attention_mask: Optional mask for encoder-decoder attention
            metadata: Optional TensorMetadata updated in place with
                modality-specific output info.

        Returns:
            Modality-projected output tensor.

        Raises:
            ValueError: If no projection exists for ``target_modality``.
        """
        # --- Self-attention sublayer (pre-norm + residual) ---
        residual = hidden_states

        with self.stream:
            hidden_states = self.ln1(hidden_states)
            hidden_states = self.self_attention(
                hidden_states,
                attention_mask=attention_mask,
                causal_mask=True
            )
            hidden_states = self.driver.add(hidden_states, residual)

        # --- Cross-attention sublayer (only when encoder output is given) ---
        if encoder_hidden_states is not None:
            residual = hidden_states

            with self.stream:
                hidden_states = self.ln2(hidden_states)
                hidden_states = self.cross_attention(
                    hidden_states,
                    encoder_hidden_states,
                    attention_mask=encoder_attention_mask
                )
                hidden_states = self.driver.add(hidden_states, residual)

        # --- Feed-forward sublayer (pre-norm + residual) ---
        residual = hidden_states

        with self.stream:
            hidden_states = self.ln3(hidden_states)

            hidden_states = self.driver.matmul(
                hidden_states,
                self.ff1['weight']
            )
            hidden_states = self.driver.add(hidden_states, self.ff1['bias'])
            hidden_states = self.gelu(hidden_states)

            hidden_states = self.driver.matmul(
                hidden_states,
                self.ff2['weight']
            )
            hidden_states = self.driver.add(hidden_states, self.ff2['bias'])

            hidden_states = self.driver.add(hidden_states, residual)

        # --- Modality-specific output projection ---
        if target_modality not in self.output_projections:
            raise ValueError(f"No projection available for modality {target_modality}")

        output = self.output_projections[target_modality].forward(hidden_states)

        # Annotate metadata (in place) with output-shape info per modality.
        if metadata is not None:
            metadata.modality = target_modality
            if target_modality == ModalityType.IMAGE:
                h, w = self.config.image_size
                metadata.spatial_dims = (h, w)
                metadata.channels = 3
            elif target_modality == ModalityType.AUDIO:
                metadata.sampling_rate = self.config.audio_params.get("sampling_rate")
            elif target_modality == ModalityType.TEXT:
                metadata.sequence_length = output.shape[1]

        return output
|
|
|
|