from typing import Optional, Dict, List, Union, Tuple import numpy as np from enum import Enum from dataclasses import dataclass from .embedding import embedding_lookup, add_positional_encoding from .positional_encoding import sinusoidal_positional_encoding from .stack import transformer_stack from .layer_norm import layer_norm from .core.db_manager import HeliumDBManager from .broadcast import ModalityType, TensorMetadata class EncoderType(Enum): """Supported encoder architectures""" TEXT = "text" VISION = "vision" AUDIO = "audio" MULTIMODAL = "multimodal" @dataclass class ModalityConfig: """Configuration for specific modalities""" modality_type: ModalityType input_channels: int = 1 patch_size: Union[int, Tuple[int, ...]] = 16 sampling_rate: Optional[int] = None frame_rate: Optional[int] = None max_seq_len: int = 1024 use_positional: bool = True use_patch_embed: bool = False @dataclass class EncoderConfig: """Configuration for TransformerEncoder""" encoder_type: EncoderType hidden_dim: int num_layers: int num_heads: int modality_configs: Dict[ModalityType, ModalityConfig] vocab_size: Optional[int] = None # Only needed for text dropout_rate: float = 0.1 layer_norm_epsilon: float = 1e-5 initializer_range: float = 0.02 use_cache: bool = True use_fp16: bool = False fusion_type: str = "concatenate" # concatenate, add, or learnable def get_total_sequence_length(self) -> int: """Get total sequence length across all modalities""" return sum(config.max_seq_len for config in self.modality_configs.values()) class EncoderCache: """Cache for storing key/value states during inference""" def __init__(self): self.layer_states: List[Tuple[np.ndarray, np.ndarray]] = [] self.position_offset: int = 0 def update(self, layer_idx: int, key: np.ndarray, value: np.ndarray): if layer_idx >= len(self.layer_states): self.layer_states.append((key, value)) else: prev_k, prev_v = self.layer_states[layer_idx] self.layer_states[layer_idx] = ( np.concatenate([prev_k, key], axis=1), np.concatenate([prev_v, value], axis=1) ) class ModalityEncoder: """Base class for modality-specific encoders""" def __init__( self, config: ModalityConfig, hidden_dim: int, driver=None ): self.config = config self.hidden_dim = hidden_dim self.driver = driver def encode(self, x: np.ndarray) -> Tuple[np.ndarray, TensorMetadata]: """Convert input to embeddings with metadata""" raise NotImplementedError class VisionEncoder(ModalityEncoder): """Vision-specific encoder with patching""" def encode(self, x: np.ndarray) -> Tuple[np.ndarray, TensorMetadata]: # Apply patch embedding if self.config.use_patch_embed: B, C, H, W = x.shape P = self.config.patch_size x = x.reshape(B, C, H//P, P, W//P, P).transpose(0,2,4,1,3,5) x = x.reshape(B, (H//P)*(W//P), C*P*P) # Project to hidden dimension if hasattr(self.driver, 'linear'): x = self.driver.linear(x, self.hidden_dim) else: x = np.random.randn(*x.shape[:-1], self.hidden_dim) metadata = TensorMetadata( modality=ModalityType.VISION, shape=x.shape, dtype=x.dtype, channels=self.config.input_channels, spatial_dims=(H, W) if 'H' in locals() else None ) return x, metadata class AudioEncoder(ModalityEncoder): """Audio-specific encoder""" def encode(self, x: np.ndarray) -> Tuple[np.ndarray, TensorMetadata]: # Apply time-frequency transform if needed if hasattr(self.driver, 'stft'): x = self.driver.stft(x) metadata = TensorMetadata( modality=ModalityType.AUDIO, shape=x.shape, dtype=x.dtype, channels=self.config.input_channels, sampling_rate=self.config.sampling_rate ) return x, metadata class TextEncoder(ModalityEncoder): """Text-specific encoder""" def __init__(self, config: ModalityConfig, hidden_dim: int, vocab_size: int, embedding_weights: np.ndarray, driver=None): super().__init__(config, hidden_dim, driver) self.vocab_size = vocab_size self.embedding_weights = embedding_weights def encode(self, x: np.ndarray) -> Tuple[np.ndarray, TensorMetadata]: x = embedding_lookup(x, self.embedding_weights, driver=self.driver) metadata = TensorMetadata( modality=ModalityType.TEXT, shape=x.shape, dtype=x.dtype, sequence_length=x.shape[1] ) return x, metadata class TransformerEncoder: """ Multi-modal Transformer Encoder implementation with support for: - Multiple input modalities (text, vision, audio) - Cross-modal attention - Modality-specific processing - Inference caching - Mixed precision (FP16/FP32) - Parallel processing - Memory optimization """ def __init__( self, config: EncoderConfig, embedding_weights: Optional[np.ndarray] = None, block_weights_list: List[Dict] = None, driver=None, scheduler=None ): """ Initialize the multi-modal transformer encoder. Args: config: Encoder configuration with modality settings embedding_weights: Optional word embedding matrix for text block_weights_list: List of weight dictionaries for transformer blocks driver: Optional hardware driver for optimized computation scheduler: Optional scheduler for parallel processing """ self.validate_inputs(config, embedding_weights, block_weights_list) self.config = config self.driver = driver self.scheduler = scheduler # Initialize modality-specific encoders self.encoders = {} for modality, modal_config in config.modality_configs.items(): if modality == ModalityType.TEXT: if embedding_weights is None: raise ValueError("embedding_weights required for text modality") self.encoders[modality] = TextEncoder( modal_config, config.hidden_dim, config.vocab_size, self._prepare_weights(embedding_weights), driver ) elif modality == ModalityType.VISION: self.encoders[modality] = VisionEncoder( modal_config, config.hidden_dim, driver ) elif modality == ModalityType.AUDIO: self.encoders[modality] = AudioEncoder( modal_config, config.hidden_dim, driver ) # Initialize transformer blocks self.block_weights_list = [ self._prepare_weights(weights) for weights in (block_weights_list or []) ] # Initialize cached computations and fusion layer self._init_cached_computations() self._init_fusion_layer() def _init_cached_computations(self): """Initialize cached components for faster inference""" # Create positional encodings for each modality self.pos_encodings = {} dtype = np.float16 if self.config.use_fp16 else np.float32 for modality, modal_config in self.config.modality_configs.items(): if modal_config.use_positional: self.pos_encodings[modality] = sinusoidal_positional_encoding( modal_config.max_seq_len, self.config.hidden_dim, dtype=dtype ) # Precompute attention bias if supported if self.driver and hasattr(self.driver, 'precompute_attention_bias'): total_seq_len = self.config.get_total_sequence_length() self.cached_attention_bias = self.driver.precompute_attention_bias( total_seq_len ) else: self.cached_attention_bias = None def _init_fusion_layer(self): """Initialize multi-modal fusion layer""" if self.config.fusion_type == "learnable": num_modalities = len(self.config.modality_configs) if self.driver and hasattr(self.driver, 'create_parameter'): self.fusion_weights = self.driver.create_parameter( (num_modalities, 1, 1), dtype=np.float16 if self.config.use_fp16 else np.float32 ) else: self.fusion_weights = np.ones((num_modalities, 1, 1)) / num_modalities else: self.fusion_weights = None def _prepare_weights(self, weights: Union[np.ndarray, Dict]) -> Union[np.ndarray, Dict]: """Convert weights to appropriate precision""" if self.config.use_fp16: if isinstance(weights, np.ndarray): return weights.astype(np.float16) return {k: v.astype(np.float16) for k, v in weights.items()} return weights def _fuse_modalities( self, encoded_states: Dict[ModalityType, np.ndarray], encoded_metadata: Dict[ModalityType, TensorMetadata] ) -> Tuple[np.ndarray, TensorMetadata]: """ Fuse multiple modalities into a single representation Supports three fusion types: 1. concatenate: Concatenate along sequence dimension 2. add: Element-wise addition (requires same shape) 3. learnable: Weighted sum using learned weights """ modalities = list(encoded_states.keys()) if len(modalities) == 1: return encoded_states[modalities[0]], encoded_metadata[modalities[0]] if self.config.fusion_type == "concatenate": # Concatenate along sequence dimension fused = np.concatenate( [encoded_states[m] for m in modalities], axis=1 ) elif self.config.fusion_type == "add": # Verify shapes match shapes = [encoded_states[m].shape for m in modalities] if not all(s == shapes[0] for s in shapes): raise ValueError( f"All modalities must have same shape for addition fusion. Got {shapes}" ) fused = sum(encoded_states[m] for m in modalities) elif self.config.fusion_type == "learnable": # Apply learned weights weighted = [ encoded_states[m] * self.fusion_weights[i] for i, m in enumerate(modalities) ] fused = sum(weighted) else: raise ValueError(f"Unknown fusion type: {self.config.fusion_type}") # Create metadata for fused representation fused_metadata = TensorMetadata( modality=ModalityType.LATENT, shape=fused.shape, dtype=fused.dtype, channels=sum(m.channels for m in encoded_metadata.values()), sequence_length=fused.shape[1] ) return fused, fused_metadata @staticmethod def validate_inputs( config: EncoderConfig, embedding_weights: np.ndarray, block_weights_list: List[Dict] ): """Validate input parameters and weights""" if embedding_weights.shape != (config.vocab_size, config.hidden_dim): raise ValueError( f"Embedding weights shape {embedding_weights.shape} doesn't match " f"config (vocab_size={config.vocab_size}, hidden_dim={config.hidden_dim})" ) if len(block_weights_list) != config.num_layers: raise ValueError( f"Expected {config.num_layers} transformer blocks, got {len(block_weights_list)}" ) def create_attention_mask( self, input_shape: Tuple[int, int], past_length: int = 0 ) -> np.ndarray: """Create causal attention mask for autoregressive inference""" batch_size, seq_length = input_shape mask = np.ones((batch_size, 1, seq_length, seq_length + past_length)) # Create causal mask for autoregressive generation if past_length > 0: mask[:, :, :, :past_length] = 1.0 return mask def forward( self, inputs: Dict[ModalityType, np.ndarray], attention_mask: Optional[np.ndarray] = None, past_cache: Optional[EncoderCache] = None, return_cache: bool = False ) -> Union[np.ndarray, Tuple[np.ndarray, EncoderCache]]: """ Forward pass of the multi-modal encoder Args: inputs: Dictionary mapping modality types to input arrays attention_mask: Optional attention mask past_cache: Optional cached key/value states return_cache: Whether to return updated cache Returns: Encoded representations, optionally with cache """ # Encode each modality encoded_states = {} encoded_metadata = {} max_seq_len = 0 for modality, x in inputs.items(): if modality not in self.encoders: raise ValueError(f"No encoder configured for modality {modality}") # Encode input states, metadata = self.encoders[modality].encode(x) encoded_states[modality] = states encoded_metadata[modality] = metadata max_seq_len = max(max_seq_len, states.shape[1]) # Pad sequences to same length for modality in encoded_states: states = encoded_states[modality] if states.shape[1] < max_seq_len: pad_len = max_seq_len - states.shape[1] encoded_states[modality] = np.pad( states, ((0, 0), (0, pad_len), (0, 0)), mode='constant' ) # Add positional encodings for modality, states in encoded_states.items(): if modality in self.pos_encodings: pos_enc = self.pos_encodings[modality][:states.shape[1]] encoded_states[modality] = states + pos_enc # Create attention mask if not provided if attention_mask is None: attention_mask = self.create_attention_mask( (encoded_states[list(encoded_states.keys())[0]].shape[0], max_seq_len), past_length=past_cache.position_offset if past_cache else 0 ) """ Forward pass through the transformer encoder. Args: input_ids: Input token IDs of shape (batch_size, seq_len) attention_mask: Optional attention mask past_cache: Optional past key/value cache for inference return_cache: Whether to return updated cache Returns: output: Encoded representations cache: Updated cache if return_cache is True """ batch_size, seq_length = input_ids.shape if seq_length > self.config.max_seq_len: raise ValueError( f"Input sequence length {seq_length} exceeds maximum " f"sequence length {self.config.max_seq_len}" ) # Fuse modalities hidden_states, fused_metadata = self._fuse_modalities( encoded_states, encoded_metadata ) # Initialize cache for current forward pass current_cache = EncoderCache() if self.config.use_cache else None if current_cache: current_cache.modality_metadata = fused_metadata # Process through transformer stack with modality-aware attention hidden_states = transformer_stack( hidden_states, self.block_weights_list, self.config.num_heads, attention_mask=attention_mask, past_cache=past_cache, current_cache=current_cache, driver=self.driver, scheduler=self.scheduler, metadata=fused_metadata ) if return_cache: return hidden_states, current_cache return hidden_states def generate( self, input_ids: np.ndarray, max_length: int, temperature: float = 1.0, top_k: int = 50, top_p: float = 0.95 ) -> np.ndarray: """ Generate sequences autoregressively. Args: input_ids: Initial input tokens max_length: Maximum sequence length to generate temperature: Sampling temperature top_k: Number of top tokens to sample from top_p: Cumulative probability threshold for nucleus sampling Returns: generated_ids: Generated token sequences """ batch_size = input_ids.shape[0] generated_ids = [list(seq) for seq in input_ids] cache = EncoderCache() for _ in range(max_length - input_ids.shape[1]): # Forward pass with caching outputs, cache = self.forward( input_ids, past_cache=cache, return_cache=True ) # Get next token logits next_token_logits = outputs[:, -1, :] # Apply temperature next_token_logits = next_token_logits / temperature # Apply top-k filtering if top_k > 0: indices_to_remove = next_token_logits < np.partition( next_token_logits, -top_k, axis=-1 )[:, -top_k:].min(axis=-1, keepdims=True) next_token_logits[indices_to_remove] = -float('inf') # Apply top-p (nucleus) filtering if top_p < 1.0: sorted_logits = np.sort(next_token_logits, axis=-1)[:, ::-1] cumsum_probs = np.cumsum(np.exp(sorted_logits), axis=-1) mask = cumsum_probs > top_p mask[:, 1:] = mask[:, :-1].copy() mask[:, 0] = 0 indices_to_remove = next_token_logits < np.min( sorted_logits[mask], axis=-1, keepdims=True ) next_token_logits[indices_to_remove] = -float('inf') # Sample next tokens probs = np.exp(next_token_logits) probs = probs / np.sum(probs, axis=-1, keepdims=True) next_tokens = np.array([ np.random.choice(self.config.vocab_size, p=p) for p in probs ]) # Update generated sequences for i in range(batch_size): generated_ids[i].append(next_tokens[i]) # Update input_ids for next iteration input_ids = next_tokens[:, np.newaxis] return np.array(generated_ids)