""" Hardware-accelerated multi-modal transformer decoder implementation for Helium virtual GPU """ from typing import Optional, Union, Dict, Any, TYPE_CHECKING, List, Tuple from dataclasses import dataclass import numpy as np from virtual_gpu_driver.src.ai.tensor_types import TensorDescriptor, DType, Device, Layout from virtual_gpu_driver.src.stream import Stream as ComputeStream from virtual_gpu_driver.src.stream import StreamManager as KernelSchedule from .main import get_device, get_default_device from .layer_norm import HeliumLayerNorm from .gelu import HeliumGELU from .multihead_attention import HeliumMultiHeadAttention from .core.db_manager import HeliumDBManager from .broadcast import ModalityType, TensorMetadata @dataclass class DecoderConfig: """Configuration for multi-modal decoder""" output_modalities: List[ModalityType] hidden_dim: int num_layers: int num_heads: int intermediate_size: int max_seq_len: Dict[ModalityType, int] vocab_size: Optional[int] = None # For text generation image_size: Optional[Tuple[int, int]] = None # For image generation audio_params: Optional[Dict[str, Any]] = None # For audio generation use_cache: bool = True dtype: str = "float16" def validate(self): """Validate configuration""" for modality in self.output_modalities: if modality == ModalityType.TEXT and not self.vocab_size: raise ValueError("vocab_size required for text generation") elif modality == ModalityType.IMAGE and not self.image_size: raise ValueError("image_size required for image generation") elif modality == ModalityType.AUDIO and not self.audio_params: raise ValueError("audio_params required for audio generation") if TYPE_CHECKING: from .main import HeliumTensor class ModalityProjection: """Projects hidden states to modality-specific outputs""" def __init__( self, config: DecoderConfig, modality: ModalityType, driver=None ): self.config = config self.modality = modality self.driver = driver if modality == ModalityType.TEXT: self.proj = self._create_linear( config.hidden_dim, config.vocab_size ) elif modality == ModalityType.IMAGE: h, w = config.image_size self.proj = self._create_linear( config.hidden_dim, h * w * 3 # RGB channels ) elif modality == ModalityType.AUDIO: self.proj = self._create_linear( config.hidden_dim, config.audio_params["num_samples"] ) def _create_linear(self, in_features: int, out_features: int) -> Dict[str, Any]: """Create projection layer""" weight_desc = TensorDescriptor( shape=(out_features, in_features), dtype=DType.FLOAT16, device=Device.VGPU, layout=Layout.ROW_MAJOR ) bias_desc = TensorDescriptor( shape=(out_features,), dtype=DType.FLOAT16, device=Device.VGPU, layout=Layout.ROW_MAJOR ) return { 'weight': self.driver.allocate_tensor(weight_desc), 'bias': self.driver.allocate_tensor(bias_desc) } def forward( self, hidden_states: Union[str, "HeliumTensor"] ) -> Union[str, "HeliumTensor"]: """Project to modality-specific output space""" out = self.driver.matmul(hidden_states, self.proj['weight']) out = self.driver.add(out, self.proj['bias']) if self.modality == ModalityType.IMAGE: # Reshape to image format (B, H, W, C) h, w = self.config.image_size out = self.driver.reshape(out, (-1, h, w, 3)) elif self.modality == ModalityType.AUDIO: # Apply audio-specific processing if self.config.audio_params.get("normalize", True): out = self.driver.tanh(out) return out class HeliumDecoderBlock: """ Hardware-accelerated multi-modal transformer decoder block Implements: 1. Self-attention with causal mask 2. Cross-attention with encoder outputs 3. Feed-forward network 4. Multi-modal output projections All operations run directly on virtual GPU with modality awareness """ def __init__( self, config: DecoderConfig, device_id: Optional[str] = None ): # Initialize device and stream self.driver = get_device(device_id) if device_id else get_default_device() self.device_id = device_id self.stream = ComputeStream(self.driver) # Initialize database connection self.db = HeliumDBManager.get_instance() # Store configuration self.config = config # Architecture parameters self.hidden_size = config.hidden_dim self.num_heads = config.num_heads self.head_dim = config.hidden_dim // config.num_heads self.intermediate_size = config.intermediate_size self.dtype = config.dtype # Initialize layer components self.self_attention = HeliumMultiHeadAttention( hidden_size=self.hidden_size, num_heads=self.num_heads, device_id=device_id, dtype=self.dtype ) self.cross_attention = HeliumMultiHeadAttention( hidden_size=self.hidden_size, num_heads=self.num_heads, device_id=device_id, dtype=self.dtype ) # Layer norms self.ln1 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype) self.ln2 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype) self.ln3 = HeliumLayerNorm(self.hidden_size, device_id=device_id, dtype=self.dtype) # Feed-forward layers self.ff1 = self._create_linear(self.hidden_size, self.intermediate_size) self.ff2 = self._create_linear(self.intermediate_size, self.hidden_size) self.gelu = HeliumGELU(device_id=device_id) # Initialize modality-specific output projections self.output_projections = { modality: ModalityProjection(config, modality, self.driver) for modality in config.output_modalities } # Operation scheduling self.schedule = KernelSchedule(self.driver) # Track allocated tensors self._temp_tensors = {} self._counter = 0 # Initialize layer components self.self_attention = HeliumMultiHeadAttention( hidden_size=hidden_size, num_heads=num_heads, device_id=device_id, dtype=dtype ) self.cross_attention = HeliumMultiHeadAttention( hidden_size=hidden_size, num_heads=num_heads, device_id=device_id, dtype=dtype ) # Layer norms self.ln1 = HeliumLayerNorm(hidden_size, device_id=device_id, dtype=dtype) self.ln2 = HeliumLayerNorm(hidden_size, device_id=device_id, dtype=dtype) self.ln3 = HeliumLayerNorm(hidden_size, device_id=device_id, dtype=dtype) # Feed-forward layers self.ff1 = self._create_linear(hidden_size, intermediate_size) self.ff2 = self._create_linear(intermediate_size, hidden_size) self.gelu = HeliumGELU(device_id=device_id) # Operation scheduling self.schedule = KernelSchedule(self.driver) # Track allocated tensors self._temp_tensors = {} self._counter = 0 def _create_linear(self, in_features: int, out_features: int) -> Dict[str, Any]: """Create a linear layer's weight tensors""" weight_desc = TensorDescriptor( shape=(out_features, in_features), dtype=getattr(DType, self.dtype.upper()), device=Device.VGPU, layout=Layout.ROW_MAJOR ) bias_desc = TensorDescriptor( shape=(out_features,), dtype=getattr(DType, self.dtype.upper()), device=Device.VGPU, layout=Layout.ROW_MAJOR ) return { 'weight': self.driver.allocate_tensor(weight_desc), 'bias': self.driver.allocate_tensor(bias_desc) } def _get_temp_tensor(self, shape: tuple) -> str: """Allocate a temporary tensor""" tensor_id = f"decoder_temp_{self._counter}" self._counter += 1 desc = TensorDescriptor( shape=shape, dtype=getattr(DType, self.dtype.upper()), device=Device.VGPU, layout=Layout.ROW_MAJOR ) self._temp_tensors[tensor_id] = self.driver.allocate_tensor(desc) return tensor_id def _free_temp_tensor(self, tensor_id: str): """Free a temporary tensor""" if tensor_id in self._temp_tensors: self.driver.free_tensor(self._temp_tensors[tensor_id]) del self._temp_tensors[tensor_id] def __del__(self): """Clean up temporary tensors""" for tensor_id in list(self._temp_tensors.keys()): self._free_temp_tensor(tensor_id) def forward( self, hidden_states: Union[str, "HeliumTensor"], target_modality: ModalityType, encoder_hidden_states: Optional[Union[str, "HeliumTensor"]] = None, attention_mask: Optional[Union[str, "HeliumTensor"]] = None, encoder_attention_mask: Optional[Union[str, "HeliumTensor"]] = None, metadata: Optional[TensorMetadata] = None ) -> Union[str, "HeliumTensor"]: """ Forward pass of decoder block Args: hidden_states: Input tensor (B, S, H) encoder_hidden_states: Optional encoder output (B, S_enc, H) attention_mask: Optional attention mask for self-attention encoder_attention_mask: Optional mask for encoder-decoder attention Returns: Output tensor (B, S, H) """ residual = hidden_states # Self attention branch with self.stream: # Layer norm 1 hidden_states = self.ln1(hidden_states) # Self attention hidden_states = self.self_attention( hidden_states, attention_mask=attention_mask, causal_mask=True # Always use causal mask in decoder ) # Residual connection hidden_states = self.driver.add(hidden_states, residual) # Cross attention branch (if encoder present) if encoder_hidden_states is not None: residual = hidden_states with self.stream: # Layer norm 2 hidden_states = self.ln2(hidden_states) # Cross attention hidden_states = self.cross_attention( hidden_states, encoder_hidden_states, attention_mask=encoder_attention_mask ) # Residual connection hidden_states = self.driver.add(hidden_states, residual) # Feed-forward branch residual = hidden_states with self.stream: # Layer norm 3 hidden_states = self.ln3(hidden_states) # Feed-forward hidden_states = self.driver.matmul( hidden_states, self.ff1['weight'] ) hidden_states = self.driver.add(hidden_states, self.ff1['bias']) hidden_states = self.gelu(hidden_states) hidden_states = self.driver.matmul( hidden_states, self.ff2['weight'] ) hidden_states = self.driver.add(hidden_states, self.ff2['bias']) # Final residual hidden_states = self.driver.add(hidden_states, residual) # Project to target modality if target_modality not in self.output_projections: raise ValueError(f"No projection available for modality {target_modality}") output = self.output_projections[target_modality].forward(hidden_states) # Update metadata if provided if metadata is not None: metadata.modality = target_modality if target_modality == ModalityType.IMAGE: h, w = self.config.image_size metadata.spatial_dims = (h, w) metadata.channels = 3 elif target_modality == ModalityType.AUDIO: metadata.sampling_rate = self.config.audio_params.get("sampling_rate") elif target_modality == ModalityType.TEXT: metadata.sequence_length = output.shape[1] return output