from __future__ import annotations

from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from typing import Any, Dict, List, Optional, Tuple, TYPE_CHECKING, Union
import hashlib
import logging
import warnings

import numpy as np

# Import local dependencies
from .core.db_manager import HeliumDBManager
from .virtual_gpu_device import VirtualGPUDevice

if TYPE_CHECKING:
    from .attention_utils import AttentionState
    from .broadcast import ModalityType
    from .tensor import HeliumTensor

# Virtual GPU device pool, keyed by device name. Devices are assumed to be
# registered here by the runtime; get_gpu_device() is the minimal accessor
# used when a driver is passed by name rather than as an object.
_gpu_devices: Dict[str, VirtualGPUDevice] = {}


def get_gpu_device(name: str) -> VirtualGPUDevice:
    """Look up a registered virtual GPU device by name."""
    try:
        return _gpu_devices[name]
    except KeyError:
        raise KeyError(f"No virtual GPU device registered under '{name}'") from None

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelType(Enum):
    """Supported model architectures"""
    GPT2 = "gpt2"
    BERT = "bert"
    T5 = "t5"
    LLAMA = "llama"
    MISTRAL = "mistral"
    FALCON = "falcon"


@dataclass
class ModelConfig:
    """Universal configuration for transformer models"""
    model_type: ModelType
    num_layers: int
    num_heads: int
    hidden_dim: int
    vocab_size: int
    max_seq_len: int
    intermediate_size: Optional[int] = None
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    use_cache: bool = True
    use_fp16: bool = False
    rotary_dim: Optional[int] = None  # For models with rotary embeddings
    vocab_padding_size: Optional[int] = None  # For vocab size optimization


class CacheManager:
    """Manages caching for model utilities"""

    def __init__(self):
        self.db = HeliumDBManager.get_instance()

    def _compute_key(self, data: Any, prefix: str) -> str:
        """Compute cache key for data"""
        if isinstance(data, np.ndarray):
            return f"{prefix}_{hashlib.sha256(data.tobytes()).hexdigest()}"
        if isinstance(data, tuple):
            data = "_".join(str(x) for x in data)
        return f"{prefix}_{hashlib.sha256(str(data).encode()).hexdigest()}"

    def get(self, key: str) -> Optional[Any]:
        """Get cached data"""
        return self.db.get_activation(key)

    def set(self, key: str, value: Any, metadata: Dict):
        """Cache data"""
        self.db.set_activation(key, value, metadata)


class MaskGenerator:
    """Optimized attention mask generator"""

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None

    # lru_cache memoizes per (self, seq_len, dtype, device); device objects
    # must therefore be hashable.
    @lru_cache(maxsize=128)
    def create_causal_mask(self, seq_len: int, dtype: np.dtype = np.bool_,
                           device=None) -> np.ndarray:
        """Create causal (autoregressive) attention mask.

        Uses caching for common sequence lengths.

        Args:
            seq_len: Sequence length
            dtype: Data type for mask
            device: Device to place mask on

        Returns:
            mask: Shape (1, 1, seq_len, seq_len) attention mask
        """
        # Check the persistent cache first
        if self.cache_manager:
            cache_key = self.cache_manager._compute_key(
                (seq_len, str(dtype)), "causal_mask"
            )
            cached_mask = self.cache_manager.get(cache_key)
            if cached_mask is not None:
                return cached_mask

        # Lower-triangular mask: position i attends to positions 0..i
        mask = np.tril(np.ones((seq_len, seq_len), dtype=dtype))
        mask = mask[np.newaxis, np.newaxis, :, :]

        # Cache if enabled
        if self.cache_manager:
            metadata = {
                "type": "causal_mask",
                "seq_len": seq_len,
                "dtype": str(dtype),
            }
            self.cache_manager.set(cache_key, mask, metadata)

        # Return, moving to device if needed
        return mask if device is None else device.to_gpu(mask)


def split_heads(
    x: Union[str, "HeliumTensor"],
    num_heads: int,
    driver,
    modality: Optional["ModalityType"] = None,  # reserved for modality-aware layouts
) -> Union[str, "HeliumTensor"]:
    """Split the hidden dim into multiple attention heads."""
    if isinstance(x, str):
        x = driver.get_tensor(x)

    batch_size, seq_len, hidden_dim = x.shape
    if hidden_dim % num_heads != 0:
        raise ValueError(
            f"hidden_dim {hidden_dim} is not divisible by num_heads {num_heads}"
        )
    head_dim = hidden_dim // num_heads

    # (batch, seq, hidden) -> (batch, seq, heads, head_dim) -> (batch, heads, seq, head_dim)
    x = driver.reshape(x, (batch_size, seq_len, num_heads, head_dim))
    x = driver.transpose(x, (0, 2, 1, 3))
    return x


def apply_rotary_embedding(
    x: Union[str, "HeliumTensor"],
    seq_len: int,
    head_dim: int,
    driver,
) -> Union[str, "HeliumTensor"]:
    """Apply rotary positional embeddings (RoPE).

    Expects x with trailing dims (..., seq_len, head_dim) and even head_dim.
    Implemented as x * cos + rotate_half(x) * sin, where rotate_half is the
    position-independent pair swap (x0, x1) -> (-x1, x0) expressed as a fixed
    matmul so it can run through the generic driver ops. Assumes driver.mul
    broadcasts the (seq_len, head_dim) tables over leading batch/head dims;
    see rotary_embedding_reference below for the plain-NumPy equivalent.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)

    # Inverse frequencies 10000^(-2i/head_dim) and per-position angles
    pos = np.arange(seq_len)
    freqs = np.exp(-np.arange(0, head_dim, 2) * np.log(10000.0) / head_dim)
    angles = pos[:, None] * freqs[None, :]  # (seq_len, head_dim // 2)

    # Expand each frequency to its (even, odd) pair -> (seq_len, head_dim)
    cos = np.repeat(np.cos(angles), 2, axis=-1)
    sin = np.repeat(np.sin(angles), 2, axis=-1)

    # Fixed permutation matrix implementing rotate_half: (x0, x1) -> (-x1, x0)
    perm = np.zeros((head_dim, head_dim), dtype=cos.dtype)
    for j in range(0, head_dim, 2):
        perm[j + 1, j] = -1.0
        perm[j, j + 1] = 1.0

    # Move the constants to the device
    cos = driver.to_gpu(cos)
    sin = driver.to_gpu(sin)
    perm = driver.to_gpu(perm)

    # x' = x * cos + rotate_half(x) * sin
    x_rot = driver.matmul(x, perm)
    x = driver.add(driver.mul(x, cos), driver.mul(x_rot, sin))
    return x
def fuse_cross_modal_attention(
    q: Union[str, "HeliumTensor"],
    k: Union[str, "HeliumTensor"],
    v: Union[str, "HeliumTensor"],
    q_modality: "ModalityType",
    kv_modality: "ModalityType",
    fusion_type: str,
    driver,
    state: "AttentionState",
) -> Tuple[Union[str, "HeliumTensor"], Union[str, "HeliumTensor"], Union[str, "HeliumTensor"]]:
    """Fuse cross-modal attention patterns"""
    if isinstance(driver, str):
        driver = get_gpu_device(driver)

    if fusion_type == "additive":
        # Simple additive fusion
        q = driver.add(q, k)
        k = q
    elif fusion_type == "multiplicative":
        # Element-wise multiplication
        q = driver.mul(q, k)
        k = q
    elif fusion_type == "gated":
        # Gated fusion with learned parameters
        gate_weight = state.stored_tensors.get("gate_weight")
        if gate_weight is None:
            raise KeyError(
                "Gated fusion requires 'gate_weight' in state.stored_tensors"
            )
        gate = driver.sigmoid(driver.matmul(q, gate_weight))
        q = driver.add(
            driver.mul(gate, q),
            driver.mul(driver.sub(1.0, gate), k),
        )
        k = q
    else:
        raise ValueError(f"Unknown fusion_type: {fusion_type}")

    return q, k, v
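
# NumPy reference sketches for the two driver-based helpers above. These are
# documentation/testing aids only: they assume plain ndarrays rather than
# HeliumTensor handles, and `gate_weight` is the learned matrix the attention
# state is assumed to carry for gated fusion.
def rotary_embedding_reference(x: np.ndarray, base: float = 10000.0) -> np.ndarray:
    """Canonical interleaved RoPE: rotate each (x[..., 2i], x[..., 2i+1]) pair.

    x has shape (..., seq_len, head_dim) with even head_dim; the result should
    match apply_rotary_embedding for a NumPy-backed driver.
    """
    seq_len, head_dim = x.shape[-2], x.shape[-1]
    inv_freq = base ** (-np.arange(0, head_dim, 2) / head_dim)   # (head_dim/2,)
    angles = np.arange(seq_len)[:, None] * inv_freq[None, :]     # (seq, head_dim/2)
    cos, sin = np.cos(angles), np.sin(angles)

    x_even, x_odd = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x_even * cos - x_odd * sin
    out[..., 1::2] = x_even * sin + x_odd * cos
    return out


def gated_fusion_reference(q: np.ndarray, k: np.ndarray,
                           gate_weight: np.ndarray) -> np.ndarray:
    """Gated fusion as in fuse_cross_modal_attention: g * q + (1 - g) * k."""
    gate = 1.0 / (1.0 + np.exp(-(q @ gate_weight)))  # sigmoid
    return gate * q + (1.0 - gate) * k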
class WeightMapper:
    """Optimized weight mapping utility for different model architectures"""

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None

    def map_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
    ) -> List[Dict[str, np.ndarray]]:
        """Map weights based on model type

        Args:
            hf_weights: HuggingFace weight dictionary
            config: Model configuration

        Returns:
            List of block weight dictionaries
        """
        if self.cache_manager:
            cache_key = self.cache_manager._compute_key(
                (list(hf_weights.keys()), config.model_type.value),
                "weight_mapping",
            )
            cached_mapping = self.cache_manager.get(cache_key)
            if cached_mapping is not None:
                return cached_mapping

        mapping_funcs = {
            ModelType.GPT2: self._map_gpt2_weights,
            ModelType.BERT: self._map_bert_weights,
            ModelType.T5: self._map_t5_weights,
            ModelType.LLAMA: self._map_llama_weights,
            # Mistral and Falcon mappers are not implemented in this module;
            # getattr keeps the lookup from raising at dict-construction time.
            ModelType.MISTRAL: getattr(self, "_map_mistral_weights", None),
            ModelType.FALCON: getattr(self, "_map_falcon_weights", None),
        }

        mapper = mapping_funcs.get(config.model_type)
        if not mapper:
            raise ValueError(f"Unsupported model type: {config.model_type}")

        result = mapper(hf_weights, config)

        if self.cache_manager:
            self.cache_manager.set(
                cache_key, result, {'model_type': config.model_type.value}
            )
        return result

    def _map_gpt2_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'transformer.h.',
    ) -> List[Dict[str, np.ndarray]]:
        """Map GPT-2 weights with optimizations"""
        block_weights_list = []
        try:
            for i in range(config.num_layers):
                block = {}

                # Layer normalization weights
                block['ln1.weight'] = hf_weights[f'{prefix}{i}.ln_1.weight']
                block['ln1.bias'] = hf_weights[f'{prefix}{i}.ln_1.bias']

                # GPT-2 stores Q, K, V fused along the output axis of c_attn
                # (shape (hidden, 3 * hidden)), so split along axis 1.
                attn_weight = hf_weights[f'{prefix}{i}.attn.c_attn.weight']
                split_size = attn_weight.shape[1] // 3
                block['attn.q_proj.weight'] = attn_weight[:, :split_size]
                block['attn.k_proj.weight'] = attn_weight[:, split_size:2 * split_size]
                block['attn.v_proj.weight'] = attn_weight[:, 2 * split_size:]

                # Output projection
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.attn.c_proj.weight']

                # Second layer norm
                block['ln2.weight'] = hf_weights[f'{prefix}{i}.ln_2.weight']
                block['ln2.bias'] = hf_weights[f'{prefix}{i}.ln_2.bias']

                # Feed-forward weights
                block['ff1.weight'] = hf_weights[f'{prefix}{i}.mlp.c_fc.weight']
                block['ff1.bias'] = hf_weights[f'{prefix}{i}.mlp.c_fc.bias']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.mlp.c_proj.weight']
                block['ff2.bias'] = hf_weights[f'{prefix}{i}.mlp.c_proj.bias']

                # Optional rotary embeddings for newer variants
                if config.rotary_dim:
                    if f'{prefix}{i}.attn.rotary_emb.inv_freq' in hf_weights:
                        block['attn.rotary_emb.inv_freq'] = hf_weights[
                            f'{prefix}{i}.attn.rotary_emb.inv_freq'
                        ]

                block_weights_list.append(block)
        except KeyError as e:
            logger.error(f"Failed to map GPT-2 weights: {e}")
            raise ValueError(f"Missing required weight: {e}") from e

        return block_weights_list
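
    # Illustrative sanity check for the fused-QKV split above. Dimensions
    # follow the public 124M GPT-2 checkpoint (hidden = 768, c_attn weight of
    # shape (768, 2304)); this helper is a documentation aid, not part of the
    # mapping API.
    @staticmethod
    def _demo_qkv_split(hidden: int = 768) -> None:
        c_attn = np.zeros((hidden, 3 * hidden), dtype=np.float32)
        q, k, v = np.split(c_attn, 3, axis=1)
        assert q.shape == k.shape == v.shape == (hidden, hidden)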
    def _map_bert_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'bert.encoder.layer.',
    ) -> List[Dict[str, np.ndarray]]:
        """Map BERT weights with optimizations"""
        block_weights_list = []
        try:
            for i in range(config.num_layers):
                block = {}

                # Layer normalization weights (BERT applies these post-attention)
                block['ln1.weight'] = hf_weights[f'{prefix}{i}.attention.output.LayerNorm.weight']
                block['ln1.bias'] = hf_weights[f'{prefix}{i}.attention.output.LayerNorm.bias']

                # Attention weights
                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.query.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.key.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.value.weight']
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.attention.output.dense.weight']

                # Second layer norm
                block['ln2.weight'] = hf_weights[f'{prefix}{i}.output.LayerNorm.weight']
                block['ln2.bias'] = hf_weights[f'{prefix}{i}.output.LayerNorm.bias']

                # Feed-forward weights
                block['ff1.weight'] = hf_weights[f'{prefix}{i}.intermediate.dense.weight']
                block['ff1.bias'] = hf_weights[f'{prefix}{i}.intermediate.dense.bias']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.output.dense.weight']
                block['ff2.bias'] = hf_weights[f'{prefix}{i}.output.dense.bias']

                # Attach position embeddings to the first block if available
                if i == 0 and 'bert.embeddings.position_embeddings.weight' in hf_weights:
                    block['position_embeddings'] = hf_weights['bert.embeddings.position_embeddings.weight']

                block_weights_list.append(block)
        except KeyError as e:
            logger.error(f"Failed to map BERT weights: {e}")
            raise ValueError(f"Missing required weight: {e}") from e

        return block_weights_list

    def _map_t5_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'encoder.block.',
    ) -> List[Dict[str, np.ndarray]]:
        """Map T5 weights with optimizations"""
        block_weights_list = []
        try:
            for i in range(config.num_layers):
                block = {}

                # Layer normalization (T5 uses RMS-style norms without bias)
                block['ln1.weight'] = hf_weights[f'{prefix}{i}.layer.0.layer_norm.weight']

                # Attention weights
                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.q.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.k.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.v.weight']
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.o.weight']

                # Second layer norm
                block['ln2.weight'] = hf_weights[f'{prefix}{i}.layer.1.layer_norm.weight']

                # Feed-forward weights (T5 dense layers carry no biases)
                block['ff1.weight'] = hf_weights[f'{prefix}{i}.layer.1.DenseReluDense.wi.weight']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.layer.1.DenseReluDense.wo.weight']

                # Relative position bias (HF exports it as an Embedding weight,
                # present only on the first block)
                rel_bias_key = f'{prefix}{i}.layer.0.SelfAttention.relative_attention_bias.weight'
                if rel_bias_key in hf_weights:
                    block['attn.relative_attention_bias'] = hf_weights[rel_bias_key]

                block_weights_list.append(block)
        except KeyError as e:
            logger.error(f"Failed to map T5 weights: {e}")
            raise ValueError(f"Missing required weight: {e}") from e

        return block_weights_list

    def _map_llama_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'model.layers.',
    ) -> List[Dict[str, np.ndarray]]:
        """Map LLaMA weights with optimizations"""
        block_weights_list = []
        try:
            for i in range(config.num_layers):
                block = {}

                # Attention weights
                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.q_proj.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.k_proj.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.v_proj.weight']
                block['attn.o_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.o_proj.weight']

                # Rotary embeddings (older checkpoints export inv_freq explicitly;
                # skip rather than store None when the key is absent)
                if config.rotary_dim:
                    inv_freq_key = f'{prefix}{i}.self_attn.rotary_emb.inv_freq'
                    if inv_freq_key in hf_weights:
                        block['attn.rotary_emb.inv_freq'] = hf_weights[inv_freq_key]

                # RMSNorm weights
                block['input_layernorm.weight'] = hf_weights[f'{prefix}{i}.input_layernorm.weight']
                block['post_attention_layernorm.weight'] = hf_weights[f'{prefix}{i}.post_attention_layernorm.weight']

                # Feed-forward (SwiGLU) weights
                block['mlp.gate_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.gate_proj.weight']
                block['mlp.up_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.up_proj.weight']
                block['mlp.down_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.down_proj.weight']

                block_weights_list.append(block)
        except KeyError as e:
            logger.error(f"Failed to map LLaMA weights: {e}")
            raise ValueError(f"Missing required weight: {e}") from e

        return block_weights_list
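
# Sketch of driving WeightMapper end-to-end with synthetic GPT-2-shaped
# weights. The keys and shapes mirror the HuggingFace layout assumed by
# _map_gpt2_weights; caching is disabled so the example has no DB dependency.
# This helper is illustrative only and not part of the public API.
def _demo_weight_mapping() -> List[Dict[str, np.ndarray]]:
    hidden = 8
    p = "transformer.h.0."
    weights = {
        p + "ln_1.weight": np.ones(hidden),
        p + "ln_1.bias": np.zeros(hidden),
        p + "attn.c_attn.weight": np.zeros((hidden, 3 * hidden)),
        p + "attn.c_proj.weight": np.zeros((hidden, hidden)),
        p + "ln_2.weight": np.ones(hidden),
        p + "ln_2.bias": np.zeros(hidden),
        p + "mlp.c_fc.weight": np.zeros((hidden, 4 * hidden)),
        p + "mlp.c_fc.bias": np.zeros(4 * hidden),
        p + "mlp.c_proj.weight": np.zeros((4 * hidden, hidden)),
        p + "mlp.c_proj.bias": np.zeros(hidden),
    }
    config = ModelConfig(
        model_type=ModelType.GPT2, num_layers=1, num_heads=2,
        hidden_dim=hidden, vocab_size=100, max_seq_len=16,
    )
    blocks = WeightMapper(use_cache=False).map_weights(weights, config)
    assert blocks[0]["attn.q_proj.weight"].shape == (hidden, hidden)
    return blocks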
class ConfigParser:
    """Enhanced configuration parser with caching"""

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None

    def parse_config(
        self,
        config: Dict[str, Any],
        model_type: Optional[ModelType] = None,
    ) -> ModelConfig:
        """Parse model configuration with caching

        Args:
            config: Configuration dictionary
            model_type: Optional model type override

        Returns:
            ModelConfig instance
        """
        if self.cache_manager:
            cache_key = self.cache_manager._compute_key(config, "config_parsing")
            cached_config = self.cache_manager.get(cache_key)
            if cached_config is not None:
                return cached_config

        # Detect model type if not provided
        if not model_type:
            model_type = self._detect_model_type(config)

        # Parse based on model type
        parsed_config = self._parse_by_type(config, model_type)

        if self.cache_manager:
            self.cache_manager.set(
                cache_key, parsed_config, {'model_type': model_type.value}
            )
        return parsed_config

    def _detect_model_type(self, config: Dict[str, Any]) -> ModelType:
        """Detect model type from config keys.

        Architecture-specific keys are checked before generic ones: LLaMA and
        Mistral configs also contain 'num_hidden_layers', so testing the
        generic BERT key first would misclassify them.
        """
        if 'sliding_window' in config:
            return ModelType.MISTRAL
        if 'multi_query_group_num' in config:
            return ModelType.FALCON
        if 'num_key_value_heads' in config:
            return ModelType.LLAMA
        if 'd_model' in config:
            return ModelType.T5
        if 'n_layer' in config:
            return ModelType.GPT2
        if 'num_hidden_layers' in config:
            return ModelType.BERT
        raise ValueError("Unable to detect model type from config")

    def _parse_by_type(
        self,
        config: Dict[str, Any],
        model_type: ModelType,
    ) -> ModelConfig:
        """Parse config based on model type.

        Note: Mistral and Falcon are detected but not yet parsed here.
        """
        if model_type == ModelType.GPT2:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['n_layer'],
                num_heads=config['n_head'],
                hidden_dim=config['n_embd'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('n_positions', 1024),
                intermediate_size=config.get('n_inner', None),
                layer_norm_epsilon=config.get('layer_norm_epsilon', 1e-5),
            )
        elif model_type == ModelType.BERT:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_hidden_layers'],
                num_heads=config['num_attention_heads'],
                hidden_dim=config['hidden_size'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('max_position_embeddings', 512),
                intermediate_size=config.get('intermediate_size', None),
                layer_norm_epsilon=config.get('layer_norm_eps', 1e-12),
            )
        elif model_type == ModelType.T5:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_layers'],
                num_heads=config['num_heads'],
                hidden_dim=config['d_model'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('n_positions', 512),
                intermediate_size=config.get('d_ff', None),
                layer_norm_epsilon=config.get('layer_norm_epsilon', 1e-6),
            )
        elif model_type == ModelType.LLAMA:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_hidden_layers'],
                num_heads=config['num_attention_heads'],
                hidden_dim=config['hidden_size'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('max_position_embeddings', 2048),
                intermediate_size=config.get('intermediate_size', None),
                layer_norm_epsilon=config.get('rms_norm_eps', 1e-6),
                rotary_dim=config.get('rotary_dim', None),
            )
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
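
# Illustration of the detection order above: a LLaMA-style config carries
# both 'num_hidden_layers' and 'num_key_value_heads', so the specific key
# must win over the generic BERT one. The minimal configs below are
# hypothetical and exist only for this sketch.
def _demo_model_type_detection() -> None:
    parser = ConfigParser(use_cache=False)
    llama_like = {"num_hidden_layers": 32, "num_key_value_heads": 8}
    bert_like = {"num_hidden_layers": 12}
    assert parser._detect_model_type(llama_like) is ModelType.LLAMA
    assert parser._detect_model_type(bert_like) is ModelType.BERT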
# Legacy function for backward compatibility
def parse_hf_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Legacy config parsing interface"""
    warnings.warn(
        "parse_hf_config is deprecated, use ConfigParser class instead",
        DeprecationWarning,
        stacklevel=2,
    )
    parser = ConfigParser(use_cache=True)
    parsed_config = parser.parse_config(config)
    return {
        'num_layers': parsed_config.num_layers,
        'num_heads': parsed_config.num_heads,
        'hidden_dim': parsed_config.hidden_dim,
        'vocab_size': parsed_config.vocab_size,
        'max_seq_len': parsed_config.max_seq_len,
    }
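
# Example entry point: parse a truncated GPT-2 small config through the
# class-based interface. The values mirror the public gpt2 checkpoint, and
# caching is disabled so the example has no DB dependency.
if __name__ == "__main__":
    sample_config = {
        "n_layer": 12,
        "n_head": 12,
        "n_embd": 768,
        "vocab_size": 50257,
        "n_positions": 1024,
    }
    parsed = ConfigParser(use_cache=False).parse_config(sample_config)
    assert parsed.model_type is ModelType.GPT2
    assert parsed.hidden_dim == 768 and parsed.max_seq_len == 1024
    logger.info("Parsed GPT-2 config: %s", parsed)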