# helium/utils.py
# (Hugging Face listing residue preserved as comments: uploaded by Fred808,
#  commit "Upload 256 files", 7a0c684, verified)
from __future__ import annotations
from typing import Dict, List, Union, Optional, Any, Tuple
import numpy as np
from dataclasses import dataclass
from enum import Enum
import warnings
import json
import hashlib
import logging
from functools import lru_cache
from pathlib import Path
# Import local dependencies
from .broadcast import ModalityType
from .attention_utils import AttentionState
from .core.db_manager import HeliumDBManager
from .virtual_gpu_device import VirtualGPUDevice
# Initialize virtual GPU device pool
_gpu_devices: Dict[str, VirtualGPUDevice] = {}
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelType(Enum):
    """Closed set of transformer architectures this module knows how to handle."""

    GPT2 = "gpt2"
    BERT = "bert"
    T5 = "t5"
    LLAMA = "llama"
    MISTRAL = "mistral"
    FALCON = "falcon"
@dataclass
class ModelConfig:
    """Universal configuration for transformer models.

    Architecture-agnostic hyper-parameter container produced by
    ``ConfigParser`` and consumed by ``WeightMapper``.
    """
    model_type: ModelType            # which architecture these values describe
    num_layers: int                  # number of transformer blocks
    num_heads: int                   # attention heads per block
    hidden_dim: int                  # model (embedding) dimension
    vocab_size: int                  # tokenizer vocabulary size
    max_seq_len: int                 # maximum supported sequence length
    intermediate_size: Optional[int] = None   # FFN inner dim; None -> model default
    layer_norm_epsilon: float = 1e-5          # numerical-stability epsilon for normalization
    initializer_range: float = 0.02           # std-dev used for weight initialization
    use_cache: bool = True                    # enable activation/KV caching
    use_fp16: bool = False                    # run in half precision
    rotary_dim: Optional[int] = None  # For models with rotary embeddings
    vocab_padding_size: Optional[int] = None  # For vocab size optimization
class CacheManager:
    """Thin caching facade over the shared HeliumDBManager instance."""

    def __init__(self):
        self.db = HeliumDBManager.get_instance()

    def _compute_key(self, data: Any, prefix: str) -> str:
        """Build a deterministic '<prefix>_<sha256-hex>' key for *data*."""
        payload = data.tobytes() if isinstance(data, np.ndarray) else str(data).encode()
        return f"{prefix}_{hashlib.sha256(payload).hexdigest()}"

    def get(self, key: str) -> Optional[Any]:
        """Return the cached value for *key*, or None when absent."""
        return self.db.get_activation(key)

    def set(self, key: str, value: Any, metadata: Dict):
        """Store *value* under *key* together with its *metadata*."""
        self.db.set_activation(key, value, metadata)
class MaskGenerator:
    """Optimized attention mask generator.

    Generated masks are memoized per instance and, when enabled, mirrored
    into the shared CacheManager backend.
    """

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None
        # Per-instance memo keyed by (seq_len, dtype-name).  Replaces the
        # previous @lru_cache on the bound method, which kept the instance
        # alive for the cache's lifetime (flake8-bugbear B019) and raised
        # TypeError whenever an unhashable `device` argument was passed.
        self._mask_memo: Dict[Tuple[int, str], np.ndarray] = {}

    def create_causal_mask(self, seq_len: int, dtype: np.dtype = np.bool_, device=None) -> np.ndarray:
        """Create causal (autoregressive) attention mask.

        Uses caching for common sequence lengths.

        Args:
            seq_len: Sequence length.
            dtype: Data type for the mask.
            device: Optional device wrapper; when given, the mask is moved
                there via ``device.to_gpu`` before being returned.

        Returns:
            Mask of shape (1, 1, seq_len, seq_len); lower triangle is 1/True.
        """
        memo_key = (seq_len, str(dtype))
        mask = self._mask_memo.get(memo_key)
        if mask is None:
            # Check the shared cache backend first
            cache_key = None
            if self.cache_manager:
                cache_key = self._compute_key(memo_key, "causal_mask")
                mask = self.cache_manager.get(cache_key)
            if mask is None:
                # Lower-triangular matrix, broadcastable over (batch, heads)
                mask = np.tril(np.ones((seq_len, seq_len), dtype=dtype))
                mask = mask[np.newaxis, np.newaxis, :, :]
                if self.cache_manager:
                    metadata = {"type": "causal_mask", "seq_len": seq_len, "dtype": str(dtype)}
                    self.cache_manager.set(cache_key, mask, metadata)
            self._mask_memo[memo_key] = mask
        # Return, moving to device if needed
        return mask if device is None else device.to_gpu(mask)

    def _compute_key(self, data: Any, prefix: str) -> str:
        """Compute a cache key; tuples are flattened into a '_'-joined string."""
        if isinstance(data, tuple):
            data = "_".join(str(x) for x in data)
        return f"{prefix}_{hashlib.sha256(str(data).encode()).hexdigest()}"
def split_heads(x: Union[str, "HeliumTensor"],
               num_heads: int,
               driver,
               modality: Optional[ModalityType] = None) -> Union[str, "HeliumTensor"]:
    """Reshape a (batch, seq, hidden) tensor into (batch, heads, seq, head_dim).

    A string argument is treated as a tensor handle and resolved through
    ``driver.get_tensor`` first.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)
    batch, seq, hidden = x.shape
    per_head = hidden // num_heads
    # (b, s, h*d) -> (b, s, h, d) -> (b, h, s, d)
    reshaped = driver.reshape(x, (batch, seq, num_heads, per_head))
    return driver.transpose(reshaped, (0, 2, 1, 3))
def apply_rotary_embedding(x: Union[str, HeliumTensor],
                           seq_len: int,
                           head_dim: int,
                           driver) -> Union[str, HeliumTensor]:
    """Apply rotary positional embeddings.

    Args:
        x: Tensor, or a string handle resolved via ``driver.get_tensor``.
        seq_len: Number of positions to generate angle tables for.
        head_dim: Per-head dimension used to derive the frequency bands.
        driver: Device driver providing ``to_gpu``, ``matmul`` and ``add``.

    Returns:
        The input combined with its rotated projection.

    NOTE(review): this does not match the standard RoPE formulation (which
    rotates even/odd feature pairs element-wise); here the input is
    matrix-multiplied with the cos/sin tables and added back, which also
    requires x's last dimension to equal ``seq_len`` for the matmuls to be
    shape-compatible — confirm intended semantics with callers.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)
    # Generate position indices 0..seq_len-1
    pos = np.arange(seq_len)
    # Inverse-frequency bands: 10000^(-2i/head_dim) for i in [0, head_dim/2)
    freqs = np.exp(-np.arange(0, head_dim, 2) * np.log(10000) / head_dim)
    angles = pos[:, None] * freqs[None, :]
    # Rotation tables, each of shape (seq_len, head_dim // 2)
    cos = np.cos(angles).reshape(seq_len, -1)
    sin = np.sin(angles).reshape(seq_len, -1)
    # Move tables to the driver's device
    cos = driver.to_gpu(cos)
    sin = driver.to_gpu(sin)
    # Combine the two projections and add back onto the input (residual-style)
    x_rot = driver.matmul(x, cos) - driver.matmul(x, sin)
    x = driver.add(x, x_rot)
    return x
def fuse_cross_modal_attention(q: Union[str, HeliumTensor],
                               k: Union[str, HeliumTensor],
                               v: Union[str, HeliumTensor],
                               q_modality: ModalityType,
                               kv_modality: ModalityType,
                               fusion_type: str,
                               driver,
                               state: AttentionState) -> Tuple[Union[str, HeliumTensor],
                                                               Union[str, HeliumTensor],
                                                               Union[str, HeliumTensor]]:
    """Fuse query/key streams from two modalities before attention.

    Supported ``fusion_type`` values are "additive", "multiplicative" and
    "gated"; any other value leaves q, k, v untouched.  After fusion, k
    aliases the fused q; v always passes through unchanged.
    """
    # A string driver is resolved to a device from the pool
    if isinstance(driver, str):
        driver = get_gpu_device(driver)
    if fusion_type == "additive":
        # q <- q + k; keys follow the fused queries
        fused = driver.add(q, k)
        q = k = fused
    elif fusion_type == "multiplicative":
        # q <- q * k (element-wise)
        fused = driver.mul(q, k)
        q = k = fused
    elif fusion_type == "gated":
        # Convex combination gate*q + (1-gate)*k with a learned gate weight
        gate = driver.sigmoid(driver.matmul(q, state.stored_tensors.get("gate_weight", None)))
        gated_q = driver.mul(gate, q)
        gated_k = driver.mul(driver.sub(1.0, gate), k)
        fused = driver.add(gated_q, gated_k)
        q = k = fused
    return q, k, v
import numpy as np
from enum import Enum
from dataclasses import dataclass
import warnings
from .core.db_manager import HeliumDBManager
import json
import hashlib
from pathlib import Path
import torch # For tensor conversion utilities
import logging
from functools import lru_cache
# Import local dependencies
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from .tensor import HeliumTensor
from .broadcast import ModalityType
from .attention_utils import AttentionState
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class ModelType(Enum):
    """Supported model architectures (string values match HF naming)."""

    GPT2 = "gpt2"
    BERT = "bert"
    T5 = "t5"
    LLAMA = "llama"
    MISTRAL = "mistral"
    FALCON = "falcon"
@dataclass
class ModelConfig:
    """Universal configuration for transformer models.

    Shared hyper-parameter container for all supported architectures;
    optional fields carry architecture-neutral defaults.
    """
    model_type: ModelType            # which architecture these values describe
    num_layers: int                  # number of transformer blocks
    num_heads: int                   # attention heads per block
    hidden_dim: int                  # model (embedding) dimension
    vocab_size: int                  # tokenizer vocabulary size
    max_seq_len: int                 # maximum supported sequence length
    intermediate_size: Optional[int] = None   # FFN inner dim; None -> model default
    layer_norm_epsilon: float = 1e-5          # numerical-stability epsilon for normalization
    initializer_range: float = 0.02           # std-dev used for weight initialization
    use_cache: bool = True                    # enable activation/KV caching
    use_fp16: bool = False                    # run in half precision
    rotary_dim: Optional[int] = None  # For models with rotary embeddings
    vocab_padding_size: Optional[int] = None  # For vocab size optimization
class CacheManager:
    """Small wrapper exposing get/set caching on top of HeliumDBManager."""

    def __init__(self):
        self.db = HeliumDBManager.get_instance()

    def _compute_key(self, data: Any, prefix: str) -> str:
        """Derive a stable cache key: '<prefix>_' + sha256 of the payload."""
        if isinstance(data, np.ndarray):
            digest = hashlib.sha256(data.tobytes()).hexdigest()
        else:
            digest = hashlib.sha256(str(data).encode()).hexdigest()
        return f"{prefix}_{digest}"

    def get(self, key: str) -> Optional[Any]:
        """Look up a cached value; None means a cache miss."""
        return self.db.get_activation(key)

    def set(self, key: str, value: Any, metadata: Dict):
        """Persist *value* and its *metadata* under *key*."""
        self.db.set_activation(key, value, metadata)
class MaskGenerator:
    """Optimized attention mask generator with layered caching."""

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None
        # In-process memo keyed by (seq_len, dtype-name).  Replaces the
        # previous @lru_cache on the bound method, which pinned the instance
        # in memory (flake8-bugbear B019) and raised TypeError whenever an
        # unhashable `device` argument was passed.
        self._mask_memo: Dict[Tuple[int, str], np.ndarray] = {}

    def create_causal_mask(self, seq_len: int, dtype: np.dtype = np.bool_, device=None) -> np.ndarray:
        """Create causal (autoregressive) attention mask.

        Uses caching for common sequence lengths.

        Args:
            seq_len: Sequence length
            dtype: Data type for mask
            device: Device to place mask on

        Returns:
            mask: Shape (1, 1, seq_len, seq_len) attention mask
        """
        memo_key = (seq_len, str(dtype))
        mask = self._mask_memo.get(memo_key)
        if mask is None:
            cache_key = None
            if self.cache_manager:
                cache_key = self.cache_manager._compute_key(memo_key, "causal_mask")
                mask = self.cache_manager.get(cache_key)
            if mask is None:
                mask = np.tril(np.ones((seq_len, seq_len), dtype=dtype))
                mask = mask[np.newaxis, np.newaxis, :, :]
                if self.cache_manager:
                    metadata = {
                        "type": "causal_mask",
                        "seq_len": seq_len,
                        "dtype": str(dtype)
                    }
                    self.cache_manager.set(cache_key, mask, metadata)
            self._mask_memo[memo_key] = mask
        return mask if device is None else device.to_gpu(mask)
def split_heads(
    x: Union[str, "HeliumTensor"],
    num_heads: int,
    driver,
    modality: Optional["ModalityType"] = None
) -> Union[str, "HeliumTensor"]:
    """Split the hidden dimension into attention heads.

    Transforms (batch, seq, hidden) into (batch, num_heads, seq, head_dim);
    string arguments are resolved to tensors via ``driver.get_tensor``.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)
    b, s, h = x.shape
    d = h // num_heads
    # (b, s, h) -> (b, s, heads, d) -> (b, heads, s, d)
    y = driver.reshape(x, (b, s, num_heads, d))
    y = driver.transpose(y, (0, 2, 1, 3))
    return y
def apply_rotary_embedding(
    x: Union[str, "HeliumTensor"],
    seq_len: int,
    head_dim: int,
    driver
) -> Union[str, "HeliumTensor"]:
    """Apply rotary positional embeddings.

    Args:
        x: Tensor, or a string handle resolved via ``driver.get_tensor``.
        seq_len: Number of positions to generate angle tables for.
        head_dim: Per-head dimension used to derive the frequency bands.
        driver: Device driver providing ``to_gpu``, ``matmul`` and ``add``.

    Returns:
        The input combined with its rotated projection.

    NOTE(review): this diverges from the standard RoPE formulation (which
    rotates even/odd feature pairs element-wise); here the input is
    matrix-multiplied with the cos/sin tables and added back, which also
    requires x's last dimension to equal ``seq_len`` for the matmuls to be
    shape-compatible — confirm intended semantics with callers.
    """
    if isinstance(x, str):
        x = driver.get_tensor(x)
    # Generate position indices 0..seq_len-1
    pos = np.arange(seq_len)
    # Inverse-frequency bands: 10000^(-2i/head_dim) for i in [0, head_dim/2)
    freqs = np.exp(-np.arange(0, head_dim, 2) * np.log(10000) / head_dim)
    angles = pos[:, None] * freqs[None, :]
    # Rotation tables, each of shape (seq_len, head_dim // 2)
    cos = np.cos(angles).reshape(seq_len, -1)
    sin = np.sin(angles).reshape(seq_len, -1)
    # Move tables to the driver's device
    cos = driver.to_gpu(cos)
    sin = driver.to_gpu(sin)
    # Combine the two projections and add back onto the input (residual-style)
    x_rot = driver.matmul(x, cos) - driver.matmul(x, sin)
    x = driver.add(x, x_rot)
    return x
def fuse_cross_modal_attention(
    q: Union[str, "HeliumTensor"],
    k: Union[str, "HeliumTensor"],
    v: Union[str, "HeliumTensor"],
    q_modality: "ModalityType",
    kv_modality: "ModalityType",
    fusion_type: str,
    driver,
    state: "AttentionState"
) -> Tuple[Union[str, "HeliumTensor"], Union[str, "HeliumTensor"], Union[str, "HeliumTensor"]]:
    """Fuse cross-modal query/key streams before attention.

    ``fusion_type`` selects "additive", "multiplicative" or "gated" fusion;
    any other value is a no-op.  After fusion, k aliases the fused q and v
    passes through unchanged.
    """
    if fusion_type == "additive":
        # q <- q + k; keys track the fused queries
        fused = driver.add(q, k)
        q = k = fused
    elif fusion_type == "multiplicative":
        # q <- q * k (element-wise)
        fused = driver.mul(q, k)
        q = k = fused
    elif fusion_type == "gated":
        # gate*q + (1-gate)*k using the learned gate weight from state
        gate = driver.sigmoid(driver.matmul(q, state.stored_tensors.get("gate_weight", None)))
        gated_q = driver.mul(gate, q)
        gated_k = driver.mul(driver.sub(1.0, gate), k)
        fused = driver.add(gated_q, gated_k)
        q = k = fused
    return q, k, v
class WeightMapper:
    """Optimized weight mapping utility for different model architectures.

    Converts a flat HuggingFace ``state_dict``-style mapping into a list of
    per-block weight dictionaries keyed by this project's canonical names.
    """

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None

    def map_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig
    ) -> List[Dict[str, np.ndarray]]:
        """
        Map weights based on model type

        Args:
            hf_weights: HuggingFace weight dictionary
            config: Model configuration

        Returns:
            List of block weight dictionaries

        Raises:
            ValueError: if the model type has no mapper or a required
                weight key is missing.
        """
        if self.cache_manager:
            cache_key = self.cache_manager._compute_key(
                (list(hf_weights.keys()), config.model_type.value),
                "weight_mapping"
            )
            cached_mapping = self.cache_manager.get(cache_key)
            if cached_mapping is not None:
                return cached_mapping
        # Only architectures with an implemented mapper are listed here.
        # (The previous version also referenced _map_mistral_weights /
        # _map_falcon_weights, which were never defined, so building this
        # dict raised AttributeError for every model type.)
        mapping_funcs = {
            ModelType.GPT2: self._map_gpt2_weights,
            ModelType.BERT: self._map_bert_weights,
            ModelType.T5: self._map_t5_weights,
            ModelType.LLAMA: self._map_llama_weights,
        }
        mapper = mapping_funcs.get(config.model_type)
        if not mapper:
            raise ValueError(f"Unsupported model type: {config.model_type}")
        result = mapper(hf_weights, config)
        if self.cache_manager:
            self.cache_manager.set(
                cache_key,
                result,
                {'model_type': config.model_type.value}
            )
        return result

    def _map_gpt2_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'transformer.h.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map GPT-2 weights: fused QKV is split into separate projections."""
        block_weights_list = []
        try:
            for i in range(config.num_layers):
                block = {}
                # Layer normalization weights
                block['ln1.weight'] = hf_weights[f'{prefix}{i}.ln_1.weight']
                block['ln1.bias'] = hf_weights[f'{prefix}{i}.ln_1.bias']
                # Fused attention matrix c_attn holds Q|K|V along the last axis
                attn_weight = hf_weights[f'{prefix}{i}.attn.c_attn.weight']
                split_size = attn_weight.shape[0] // 3
                block['attn.q_proj.weight'] = attn_weight[:, :split_size]
                block['attn.k_proj.weight'] = attn_weight[:, split_size:2*split_size]
                block['attn.v_proj.weight'] = attn_weight[:, 2*split_size:]
                # Output projection
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.attn.c_proj.weight']
                # Second layer norm
                block['ln2.weight'] = hf_weights[f'{prefix}{i}.ln_2.weight']
                block['ln2.bias'] = hf_weights[f'{prefix}{i}.ln_2.bias']
                # Feed-forward weights
                block['ff1.weight'] = hf_weights[f'{prefix}{i}.mlp.c_fc.weight']
                block['ff1.bias'] = hf_weights[f'{prefix}{i}.mlp.c_fc.bias']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.mlp.c_proj.weight']
                block['ff2.bias'] = hf_weights[f'{prefix}{i}.mlp.c_proj.bias']
                # Optional rotary embeddings for newer variants
                if config.rotary_dim:
                    if f'{prefix}{i}.attn.rotary_emb.inv_freq' in hf_weights:
                        block['attn.rotary_emb.inv_freq'] = hf_weights[f'{prefix}{i}.attn.rotary_emb.inv_freq']
                block_weights_list.append(block)
        except KeyError as e:
            logger.error(f"Failed to map GPT-2 weights: {str(e)}")
            raise ValueError(f"Missing required weight: {str(e)}")
        return block_weights_list

    def _map_bert_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'bert.encoder.layer.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map BERT weights; position embeddings ride along with block 0."""
        block_weights_list = []
        try:
            for i in range(config.num_layers):
                block = {}
                # Layer normalization weights
                block['ln1.weight'] = hf_weights[f'{prefix}{i}.attention.output.LayerNorm.weight']
                block['ln1.bias'] = hf_weights[f'{prefix}{i}.attention.output.LayerNorm.bias']
                # Attention weights (separate Q/K/V in BERT checkpoints)
                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.query.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.key.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.value.weight']
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.attention.output.dense.weight']
                # Second layer norm
                block['ln2.weight'] = hf_weights[f'{prefix}{i}.output.LayerNorm.weight']
                block['ln2.bias'] = hf_weights[f'{prefix}{i}.output.LayerNorm.bias']
                # Feed-forward weights
                block['ff1.weight'] = hf_weights[f'{prefix}{i}.intermediate.dense.weight']
                block['ff1.bias'] = hf_weights[f'{prefix}{i}.intermediate.dense.bias']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.output.dense.weight']
                block['ff2.bias'] = hf_weights[f'{prefix}{i}.output.dense.bias']
                # Add position embeddings if available (first block only)
                if i == 0 and 'bert.embeddings.position_embeddings.weight' in hf_weights:
                    block['position_embeddings'] = hf_weights['bert.embeddings.position_embeddings.weight']
                block_weights_list.append(block)
        except KeyError as e:
            logger.error(f"Failed to map BERT weights: {str(e)}")
            raise ValueError(f"Missing required weight: {str(e)}")
        return block_weights_list

    def _map_t5_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'encoder.block.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map T5 encoder weights; T5 layer norms have no bias terms."""
        block_weights_list = []
        try:
            for i in range(config.num_layers):
                block = {}
                # Layer normalization
                block['ln1.weight'] = hf_weights[f'{prefix}{i}.layer.0.layer_norm.weight']
                # Attention weights
                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.q.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.k.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.v.weight']
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.o.weight']
                # Second layer norm
                block['ln2.weight'] = hf_weights[f'{prefix}{i}.layer.1.layer_norm.weight']
                # Feed-forward weights
                block['ff1.weight'] = hf_weights[f'{prefix}{i}.layer.1.DenseReluDense.wi.weight']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.layer.1.DenseReluDense.wo.weight']
                # Relative position bias if available (usually block 0 only)
                if f'{prefix}{i}.layer.0.SelfAttention.relative_attention_bias' in hf_weights:
                    block['attn.relative_attention_bias'] = hf_weights[
                        f'{prefix}{i}.layer.0.SelfAttention.relative_attention_bias'
                    ]
                block_weights_list.append(block)
        except KeyError as e:
            logger.error(f"Failed to map T5 weights: {str(e)}")
            raise ValueError(f"Missing required weight: {str(e)}")
        return block_weights_list

    def _map_llama_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'model.layers.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map LLaMA weights (RMSNorm + gated MLP layout)."""
        block_weights_list = []
        try:
            for i in range(config.num_layers):
                block = {}
                # Attention weights
                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.q_proj.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.k_proj.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.v_proj.weight']
                block['attn.o_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.o_proj.weight']
                # Rotary embeddings: only store the entry when present.
                # (Previously a missing key stored None via dict.get, which
                # downstream code could not distinguish from a real tensor.)
                if config.rotary_dim:
                    inv_freq = hf_weights.get(f'{prefix}{i}.self_attn.rotary_emb.inv_freq')
                    if inv_freq is not None:
                        block['attn.rotary_emb.inv_freq'] = inv_freq
                # RMSNorm weights
                block['input_layernorm.weight'] = hf_weights[f'{prefix}{i}.input_layernorm.weight']
                block['post_attention_layernorm.weight'] = hf_weights[f'{prefix}{i}.post_attention_layernorm.weight']
                # Feed-forward weights (SwiGLU-style gate/up/down)
                block['mlp.gate_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.gate_proj.weight']
                block['mlp.up_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.up_proj.weight']
                block['mlp.down_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.down_proj.weight']
                block_weights_list.append(block)
        except KeyError as e:
            logger.error(f"Failed to map LLaMA weights: {str(e)}")
            raise ValueError(f"Missing required weight: {str(e)}")
        return block_weights_list
class ConfigParser:
    """Enhanced configuration parser with caching.

    Turns a raw HuggingFace config dict into a :class:`ModelConfig`,
    auto-detecting the architecture when not given explicitly.
    """

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None

    def parse_config(
        self,
        config: Dict[str, Any],
        model_type: Optional[ModelType] = None
    ) -> ModelConfig:
        """
        Parse model configuration with caching

        Args:
            config: Configuration dictionary
            model_type: Optional model type override

        Returns:
            ModelConfig instance

        Raises:
            ValueError: if the model type cannot be detected or parsed.
        """
        if self.cache_manager:
            cache_key = self.cache_manager._compute_key(config, "config_parsing")
            cached_config = self.cache_manager.get(cache_key)
            if cached_config is not None:
                return cached_config
        # Detect model type if not provided
        if not model_type:
            model_type = self._detect_model_type(config)
        # Parse based on model type
        parsed_config = self._parse_by_type(config, model_type)
        if self.cache_manager:
            self.cache_manager.set(
                cache_key,
                parsed_config,
                {'model_type': model_type.value}
            )
        return parsed_config

    def _detect_model_type(self, config: Dict[str, Any]) -> ModelType:
        """Detect model type from config keys, most-specific first.

        The previous ordering tested the generic 'num_hidden_layers' key
        (BERT) before the LLaMA/Mistral-specific keys; since LLaMA and
        Mistral configs also contain 'num_hidden_layers', they were always
        misdetected as BERT.  Distinctive keys are now checked first.
        """
        if 'n_layer' in config:
            return ModelType.GPT2
        elif 'sliding_window' in config:
            return ModelType.MISTRAL
        elif 'num_key_value_heads' in config:
            return ModelType.LLAMA
        elif 'multi_query_group_num' in config:
            return ModelType.FALCON
        elif 'd_model' in config:
            return ModelType.T5
        elif 'num_hidden_layers' in config:
            return ModelType.BERT
        else:
            raise ValueError("Unable to detect model type from config")

    def _parse_by_type(
        self,
        config: Dict[str, Any],
        model_type: ModelType
    ) -> ModelConfig:
        """Build a ModelConfig from the architecture-specific key names."""
        if model_type == ModelType.GPT2:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['n_layer'],
                num_heads=config['n_head'],
                hidden_dim=config['n_embd'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('n_positions', 1024),
                intermediate_size=config.get('n_inner', None),
                layer_norm_epsilon=config.get('layer_norm_epsilon', 1e-5)
            )
        elif model_type == ModelType.BERT:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_hidden_layers'],
                num_heads=config['num_attention_heads'],
                hidden_dim=config['hidden_size'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('max_position_embeddings', 512),
                intermediate_size=config.get('intermediate_size', None),
                layer_norm_epsilon=config.get('layer_norm_eps', 1e-12)
            )
        elif model_type == ModelType.T5:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_layers'],
                num_heads=config['num_heads'],
                hidden_dim=config['d_model'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('n_positions', 512),
                intermediate_size=config.get('d_ff', None),
                layer_norm_epsilon=config.get('layer_norm_epsilon', 1e-6)
            )
        elif model_type == ModelType.LLAMA:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_hidden_layers'],
                num_heads=config['num_attention_heads'],
                hidden_dim=config['hidden_size'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('max_position_embeddings', 2048),
                intermediate_size=config.get('intermediate_size', None),
                layer_norm_epsilon=config.get('rms_norm_eps', 1e-6),
                rotary_dim=config.get('rotary_dim', None)
            )
        else:
            raise ValueError(f"Unsupported model type: {model_type}")
# Legacy function for backward compatibility
def parse_hf_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Deprecated dict-based wrapper around :class:`ConfigParser`."""
    warnings.warn(
        "parse_hf_config is deprecated, use ConfigParser class instead",
        DeprecationWarning
    )
    cfg = ConfigParser(use_cache=True).parse_config(config)
    return {
        'num_layers': cfg.num_layers,
        'num_heads': cfg.num_heads,
        'hidden_dim': cfg.hidden_dim,
        'vocab_size': cfg.vocab_size,
        'max_seq_len': cfg.max_seq_len
    }