# helium/decoder.py
"""
Hardware-accelerated multi-modal transformer decoder implementation for Helium virtual GPU
"""
from typing import Optional, Union, Dict, Any, TYPE_CHECKING, List, Tuple
from dataclasses import dataclass
import numpy as np
from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, DType, Device, Layout
from virtual_gpu_driver.src.stream import Stream as ComputeStream
from virtual_gpu_driver.src.stream import StreamManager as KernelSchedule
from .main import get_device, get_default_device
from .layer_norm import HeliumLayerNorm
from .gelu import HeliumGELU
from .multihead_attention import HeliumMultiHeadAttention
from .core.db_manager import HeliumDBManager
from .broadcast import ModalityType, TensorMetadata
@dataclass
class DecoderConfig:
    """Configuration for the multi-modal decoder.

    Holds the shared transformer geometry plus the per-modality parameters
    required to build output projections; call :meth:`validate` after
    construction to check that every requested modality is fully specified.
    """
    output_modalities: List[ModalityType]  # modalities this decoder can emit
    hidden_dim: int  # transformer hidden size
    num_layers: int  # number of decoder blocks
    num_heads: int  # attention heads per block
    intermediate_size: int  # feed-forward inner dimension
    max_seq_len: Dict[ModalityType, int]  # per-modality maximum sequence length
    vocab_size: Optional[int] = None  # For text generation
    image_size: Optional[Tuple[int, int]] = None  # For image generation
    audio_params: Optional[Dict[str, Any]] = None  # For audio generation
    use_cache: bool = True
    dtype: str = "float16"

    def validate(self) -> None:
        """Raise ValueError when a requested modality lacks its required field."""
        for modality in self.output_modalities:
            if modality == ModalityType.TEXT:
                if not self.vocab_size:
                    raise ValueError("vocab_size required for text generation")
            elif modality == ModalityType.IMAGE:
                if not self.image_size:
                    raise ValueError("image_size required for image generation")
            elif modality == ModalityType.AUDIO:
                if not self.audio_params:
                    raise ValueError("audio_params required for audio generation")
if TYPE_CHECKING:
from .main import HeliumTensor
class ModalityProjection:
    """Projects decoder hidden states to modality-specific outputs.

    One instance is created per output modality; it owns a single linear
    projection (weight + bias) allocated on the virtual GPU via *driver*.
    """
    def __init__(
        self,
        config: DecoderConfig,
        modality: ModalityType,
        driver=None
    ):
        """
        Args:
            config: Decoder configuration; must carry the field required by
                *modality* (vocab_size / image_size / audio_params).
            modality: Target output modality for this projection.
            driver: Virtual GPU driver used to allocate tensors and run ops.

        Raises:
            ValueError: If *modality* is not TEXT, IMAGE, or AUDIO.
        """
        self.config = config
        self.modality = modality
        self.driver = driver
        if modality == ModalityType.TEXT:
            self.proj = self._create_linear(
                config.hidden_dim,
                config.vocab_size
            )
        elif modality == ModalityType.IMAGE:
            h, w = config.image_size
            self.proj = self._create_linear(
                config.hidden_dim,
                h * w * 3  # RGB channels
            )
        elif modality == ModalityType.AUDIO:
            self.proj = self._create_linear(
                config.hidden_dim,
                config.audio_params["num_samples"]
            )
        else:
            # Fail fast: previously an unsupported modality left self.proj
            # unset and the crash surfaced later in forward() as a confusing
            # AttributeError.
            raise ValueError(f"Unsupported output modality: {modality}")

    def _create_linear(self, in_features: int, out_features: int) -> Dict[str, Any]:
        """Allocate weight/bias tensors for one linear projection layer."""
        # Use the configured dtype rather than hardcoding FLOAT16, for
        # consistency with HeliumDecoderBlock._create_linear. Behavior is
        # unchanged for the default config (dtype="float16").
        dtype = getattr(DType, self.config.dtype.upper())
        weight_desc = TensorDescriptor(
            shape=(out_features, in_features),
            dtype=dtype,
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        bias_desc = TensorDescriptor(
            shape=(out_features,),
            dtype=dtype,
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        return {
            'weight': self.driver.allocate_tensor(weight_desc),
            'bias': self.driver.allocate_tensor(bias_desc)
        }

    def forward(
        self,
        hidden_states: Union[str, "HeliumTensor"]
    ) -> Union[str, "HeliumTensor"]:
        """Project hidden states into this modality's output space.

        Args:
            hidden_states: Decoder output; assumed (B, S, hidden_dim) —
                TODO confirm against the driver.matmul contract.

        Returns:
            TEXT: logits over the vocabulary.
            IMAGE: tensor reshaped to (B, H, W, 3).
            AUDIO: raw samples, tanh-squashed unless
                audio_params["normalize"] is falsy (defaults to True).
        """
        out = self.driver.matmul(hidden_states, self.proj['weight'])
        out = self.driver.add(out, self.proj['bias'])
        if self.modality == ModalityType.IMAGE:
            # Reshape flat pixel vector to image format (B, H, W, C)
            h, w = self.config.image_size
            out = self.driver.reshape(out, (-1, h, w, 3))
        elif self.modality == ModalityType.AUDIO:
            # Apply audio-specific processing: squash into [-1, 1]
            if self.config.audio_params.get("normalize", True):
                out = self.driver.tanh(out)
        return out
class HeliumDecoderBlock:
    """
    Hardware-accelerated multi-modal transformer decoder block
    Implements:
    1. Self-attention with causal mask
    2. Cross-attention with encoder outputs
    3. Feed-forward network
    4. Multi-modal output projections
    All operations run directly on virtual GPU with modality awareness
    """
    def __init__(
        self,
        config: DecoderConfig,
        device_id: Optional[str] = None
    ):
        """
        Args:
            config: Decoder configuration (dimensions, modalities, dtype).
            device_id: Optional virtual GPU device id; when None, the
                process-wide default device is used.
        """
        # Initialize device and stream
        self.driver = get_device(device_id) if device_id else get_default_device()
        self.device_id = device_id
        self.stream = ComputeStream(self.driver)
        # Initialize database connection (singleton)
        self.db = HeliumDBManager.get_instance()
        # Store configuration
        self.config = config
        # Architecture parameters
        self.hidden_size = config.hidden_dim
        self.num_heads = config.num_heads
        self.head_dim = config.hidden_dim // config.num_heads
        self.intermediate_size = config.intermediate_size
        self.dtype = config.dtype
        # Attention layers (self + encoder-decoder cross attention)
        self.self_attention = HeliumMultiHeadAttention(
            hidden_size=self.hidden_size,
            num_heads=self.num_heads,
            device_id=device_id,
            dtype=self.dtype
        )
        self.cross_attention = HeliumMultiHeadAttention(
            hidden_size=self.hidden_size,
            num_heads=self.num_heads,
            device_id=device_id,
            dtype=self.dtype
        )
        # Layer norms (pre-norm, one per sub-layer)
        self.ln1 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype)
        self.ln2 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype)
        self.ln3 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype)
        # Feed-forward layers
        self.ff1 = self._create_linear(self.hidden_size, self.intermediate_size)
        self.ff2 = self._create_linear(self.intermediate_size, self.hidden_size)
        self.gelu = HeliumGELU(device_id=device_id)
        # Initialize modality-specific output projections
        self.output_projections = {
            modality: ModalityProjection(config, modality, self.driver)
            for modality in config.output_modalities
        }
        # Operation scheduling
        self.schedule = KernelSchedule(self.driver)
        # Track allocated tensors
        self._temp_tensors = {}
        self._counter = 0
        # NOTE(review): a duplicated initialization block that re-created the
        # attention/norm/FFN layers using undefined bare names (hidden_size,
        # num_heads, dtype, intermediate_size) was removed here — it raised
        # NameError on every construction and would have leaked the first set
        # of allocated weight tensors.

    def _create_linear(self, in_features: int, out_features: int) -> Dict[str, Any]:
        """Create a linear layer's weight tensors on the virtual GPU."""
        weight_desc = TensorDescriptor(
            shape=(out_features, in_features),
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        bias_desc = TensorDescriptor(
            shape=(out_features,),
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        return {
            'weight': self.driver.allocate_tensor(weight_desc),
            'bias': self.driver.allocate_tensor(bias_desc)
        }

    def _get_temp_tensor(self, shape: tuple) -> str:
        """Allocate a temporary tensor and return its tracking id."""
        tensor_id = f"decoder_temp_{self._counter}"
        self._counter += 1
        desc = TensorDescriptor(
            shape=shape,
            dtype=getattr(DType, self.dtype.upper()),
            device=Device.VGPU,
            layout=Layout.ROW_MAJOR
        )
        self._temp_tensors[tensor_id] = self.driver.allocate_tensor(desc)
        return tensor_id

    def _free_temp_tensor(self, tensor_id: str):
        """Free a temporary tensor; unknown ids are ignored."""
        if tensor_id in self._temp_tensors:
            self.driver.free_tensor(self._temp_tensors[tensor_id])
            del self._temp_tensors[tensor_id]

    def __del__(self):
        """Clean up any temporary tensors still held at destruction time."""
        for tensor_id in list(self._temp_tensors.keys()):
            self._free_temp_tensor(tensor_id)

    def forward(
        self,
        hidden_states: Union[str, "HeliumTensor"],
        target_modality: ModalityType,
        encoder_hidden_states: Optional[Union[str, "HeliumTensor"]] = None,
        attention_mask: Optional[Union[str, "HeliumTensor"]] = None,
        encoder_attention_mask: Optional[Union[str, "HeliumTensor"]] = None,
        metadata: Optional[TensorMetadata] = None
    ) -> Union[str, "HeliumTensor"]:
        """
        Forward pass of decoder block
        Args:
            hidden_states: Input tensor (B, S, H)
            target_modality: Output modality to project the block output into;
                must be one of config.output_modalities
            encoder_hidden_states: Optional encoder output (B, S_enc, H)
            attention_mask: Optional attention mask for self-attention
            encoder_attention_mask: Optional mask for encoder-decoder attention
            metadata: Optional metadata object updated in place with the
                target modality's output dimensions
        Returns:
            Modality-projected output tensor
        Raises:
            ValueError: If no projection exists for target_modality
        """
        residual = hidden_states
        # Self attention branch (pre-norm + causal self-attention + residual)
        with self.stream:
            hidden_states = self.ln1(hidden_states)
            hidden_states = self.self_attention(
                hidden_states,
                attention_mask=attention_mask,
                causal_mask=True  # Always use causal mask in decoder
            )
            hidden_states = self.driver.add(hidden_states, residual)
        # Cross attention branch (only when an encoder output is present)
        if encoder_hidden_states is not None:
            residual = hidden_states
            with self.stream:
                hidden_states = self.ln2(hidden_states)
                hidden_states = self.cross_attention(
                    hidden_states,
                    encoder_hidden_states,
                    attention_mask=encoder_attention_mask
                )
                hidden_states = self.driver.add(hidden_states, residual)
        # Feed-forward branch (pre-norm + linear/GELU/linear + residual)
        residual = hidden_states
        with self.stream:
            hidden_states = self.ln3(hidden_states)
            hidden_states = self.driver.matmul(
                hidden_states,
                self.ff1['weight']
            )
            hidden_states = self.driver.add(hidden_states, self.ff1['bias'])
            hidden_states = self.gelu(hidden_states)
            hidden_states = self.driver.matmul(
                hidden_states,
                self.ff2['weight']
            )
            hidden_states = self.driver.add(hidden_states, self.ff2['bias'])
            # Final residual
            hidden_states = self.driver.add(hidden_states, residual)
        # Project to target modality
        if target_modality not in self.output_projections:
            raise ValueError(f"No projection available for modality {target_modality}")
        output = self.output_projections[target_modality].forward(hidden_states)
        # Update metadata in place if provided
        if metadata is not None:
            metadata.modality = target_modality
            if target_modality == ModalityType.IMAGE:
                h, w = self.config.image_size
                metadata.spatial_dims = (h, w)
                metadata.channels = 3
            elif target_modality == ModalityType.AUDIO:
                metadata.sampling_rate = self.config.audio_params.get("sampling_rate")
            elif target_modality == ModalityType.TEXT:
                # assumes the projected output exposes .shape even when it is
                # a string tensor id — TODO confirm against the driver API
                metadata.sequence_length = output.shape[1]
        return output