from __future__ import annotations

import hashlib
import json
import logging
import warnings
from dataclasses import dataclass
from enum import Enum
from functools import lru_cache
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union

import numpy as np

from .attention_utils import AttentionState
from .broadcast import ModalityType
from .core.db_manager import HeliumDBManager
from .virtual_gpu_device import VirtualGPUDevice

if TYPE_CHECKING:
    from .tensor import HeliumTensor

# Registry of named virtual GPU devices, populated elsewhere in the package.
_gpu_devices: Dict[str, VirtualGPUDevice] = {}


def get_gpu_device(name: str) -> VirtualGPUDevice:
    """Return the registered VirtualGPUDevice for ``name``.

    Minimal lookup against the module-level registry; devices are expected to
    be registered in ``_gpu_devices`` before this is called.
    """
    try:
        return _gpu_devices[name]
    except KeyError:
        raise KeyError(f"No virtual GPU device registered under '{name}'") from None


logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


class ModelType(Enum):
    """Supported model architectures."""
    GPT2 = "gpt2"
    BERT = "bert"
    T5 = "t5"
    LLAMA = "llama"
    MISTRAL = "mistral"
    FALCON = "falcon"


@dataclass
class ModelConfig:
    """Universal configuration for transformer models."""
    model_type: ModelType
    num_layers: int
    num_heads: int
    hidden_dim: int
    vocab_size: int
    max_seq_len: int
    intermediate_size: Optional[int] = None
    layer_norm_epsilon: float = 1e-5
    initializer_range: float = 0.02
    use_cache: bool = True
    use_fp16: bool = False
    rotary_dim: Optional[int] = None
    vocab_padding_size: Optional[int] = None


class CacheManager:
    """Manages caching for model utilities."""

    def __init__(self):
        self.db = HeliumDBManager.get_instance()

    def _compute_key(self, data: Any, prefix: str) -> str:
        """Compute a cache key for ``data``."""
        if isinstance(data, np.ndarray):
            return f"{prefix}_{hashlib.sha256(data.tobytes()).hexdigest()}"
        return f"{prefix}_{hashlib.sha256(str(data).encode()).hexdigest()}"

    def get(self, key: str) -> Optional[Any]:
        """Return cached data for ``key``, or None if absent."""
        return self.db.get_activation(key)

    def set(self, key: str, value: Any, metadata: Dict):
        """Cache ``value`` under ``key`` with associated metadata."""
        self.db.set_activation(key, value, metadata)


class MaskGenerator:
    """Optimized attention mask generator."""

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None

    # lru_cache keys on (self, seq_len, dtype, device), so all arguments must
    # be hashable; the DB-backed cache below additionally persists the mask.
    @lru_cache(maxsize=128)
    def create_causal_mask(self, seq_len: int, dtype: np.dtype = np.bool_, device=None) -> np.ndarray:
        """Create causal (autoregressive) attention mask.

        Uses caching for common sequence lengths.

        Args:
            seq_len: Sequence length
            dtype: Data type for mask
            device: Device to place mask on

        Returns:
            mask: Shape (1, 1, seq_len, seq_len) attention mask
        """
        if self.cache_manager:
            cache_key = self._compute_key((seq_len, str(dtype)), "causal_mask")
            cached_mask = self.cache_manager.get(cache_key)
            if cached_mask is not None:
                return cached_mask

        # Lower-triangular matrix: position i may attend to positions <= i.
        mask = np.tril(np.ones((seq_len, seq_len), dtype=dtype))
        mask = mask[np.newaxis, np.newaxis, :, :]

        if self.cache_manager:
            metadata = {"type": "causal_mask", "seq_len": seq_len, "dtype": str(dtype)}
            self.cache_manager.set(cache_key, mask, metadata)

        return mask if device is None else device.to_gpu(mask)

    def _compute_key(self, data: Any, prefix: str) -> str:
        """Compute a cache key for ``data``."""
        if isinstance(data, tuple):
            data = "_".join(str(x) for x in data)
        return f"{prefix}_{hashlib.sha256(str(data).encode()).hexdigest()}"


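# Usage sketch (hypothetical values; assumes use_cache=False so no DB-backed
# cache is needed):
#
#     gen = MaskGenerator(use_cache=False)
#     mask = gen.create_causal_mask(4, dtype=np.float32)
#     # mask.shape == (1, 1, 4, 4); row i is 1 for columns <= i, else 0

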
def split_heads(
    x: Union[str, "HeliumTensor"],
    num_heads: int,
    driver,
    modality: Optional[ModalityType] = None
) -> Union[str, "HeliumTensor"]:
    """Split hidden dim into multiple heads."""
    if isinstance(x, str):
        x = driver.get_tensor(x)

    batch_size, seq_len, hidden_dim = x.shape
    head_dim = hidden_dim // num_heads

    x = driver.reshape(x, (batch_size, seq_len, num_heads, head_dim))
    x = driver.transpose(x, (0, 2, 1, 3))

    return x


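# Shape sketch for split_heads: with batch=2, seq=16, hidden=768 and
# num_heads=12, the input (2, 16, 768) is reshaped to (2, 16, 12, 64) and then
# transposed to (2, 12, 16, 64), i.e. one 64-dim slice per head so attention
# scores can be computed per head.

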
def apply_rotary_embedding(
    x: Union[str, "HeliumTensor"],
    seq_len: int,
    head_dim: int,
    driver
) -> Union[str, "HeliumTensor"]:
    """Apply rotary positional embeddings."""
    if isinstance(x, str):
        x = driver.get_tensor(x)

    # Rotary frequency table: freqs[i] = 10000 ** (-2i / head_dim).
    pos = np.arange(seq_len)
    freqs = np.exp(-np.arange(0, head_dim, 2) * np.log(10000) / head_dim)
    angles = pos[:, None] * freqs[None, :]

    cos = np.cos(angles).reshape(seq_len, -1)
    sin = np.sin(angles).reshape(seq_len, -1)

    cos = driver.to_gpu(cos)
    sin = driver.to_gpu(sin)

    # Use driver.sub rather than the "-" operator so the result stays a
    # driver-managed tensor (string handles may not support Python operators).
    x_rot = driver.sub(driver.matmul(x, cos), driver.matmul(x, sin))
    x = driver.add(x, x_rot)

    return x


def fuse_cross_modal_attention(
    q: Union[str, "HeliumTensor"],
    k: Union[str, "HeliumTensor"],
    v: Union[str, "HeliumTensor"],
    q_modality: ModalityType,
    kv_modality: ModalityType,
    fusion_type: str,
    driver,
    state: AttentionState
) -> Tuple[Union[str, "HeliumTensor"], Union[str, "HeliumTensor"], Union[str, "HeliumTensor"]]:
    """Fuse cross-modal attention patterns."""
    if isinstance(driver, str):
        driver = get_gpu_device(driver)

    if fusion_type == "additive":
        q = driver.add(q, k)
        k = q
    elif fusion_type == "multiplicative":
        q = driver.mul(q, k)
        k = q
    elif fusion_type == "gated":
        gate_weight = state.stored_tensors.get("gate_weight")
        if gate_weight is None:
            raise ValueError("Gated fusion requires 'gate_weight' in state.stored_tensors")
        gate = driver.sigmoid(driver.matmul(q, gate_weight))
        q = driver.add(
            driver.mul(gate, q),
            driver.mul(driver.sub(1.0, gate), k)
        )
        k = q
    # Unrecognized fusion types fall through and return q, k, v unchanged.

    return q, k, v


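# Fusion note: the "gated" branch computes gate = sigmoid(q @ W_gate) and then
# q' = gate * q + (1 - gate) * k, a learned convex blend of the query and key
# streams; "additive" and "multiplicative" are the unlearned element-wise
# analogues. W_gate is looked up under state.stored_tensors["gate_weight"].

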
class WeightMapper:
    """Optimized weight mapping utility for different model architectures."""

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None

    def map_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig
    ) -> List[Dict[str, np.ndarray]]:
        """
        Map weights based on model type.

        Args:
            hf_weights: HuggingFace weight dictionary
            config: Model configuration

        Returns:
            List of block weight dictionaries
        """
        if self.cache_manager:
            cache_key = self.cache_manager._compute_key(
                (list(hf_weights.keys()), config.model_type.value),
                "weight_mapping"
            )
            cached_mapping = self.cache_manager.get(cache_key)
            if cached_mapping is not None:
                return cached_mapping

        mapping_funcs = {
            ModelType.GPT2: self._map_gpt2_weights,
            ModelType.BERT: self._map_bert_weights,
            ModelType.T5: self._map_t5_weights,
            ModelType.LLAMA: self._map_llama_weights,
            ModelType.MISTRAL: self._map_mistral_weights,
            ModelType.FALCON: self._map_falcon_weights
        }

        mapper = mapping_funcs.get(config.model_type)
        if not mapper:
            raise ValueError(f"Unsupported model type: {config.model_type}")

        result = mapper(hf_weights, config)

        if self.cache_manager:
            self.cache_manager.set(
                cache_key,
                result,
                {'model_type': config.model_type.value}
            )

        return result

    def _map_gpt2_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'transformer.h.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map GPT-2 weights with optimizations."""
        block_weights_list = []

        try:
            for i in range(config.num_layers):
                block = {}

                block['ln1.weight'] = hf_weights[f'{prefix}{i}.ln_1.weight']
                block['ln1.bias'] = hf_weights[f'{prefix}{i}.ln_1.bias']

                # GPT-2 stores Q, K, V as one fused c_attn matrix; split it
                # along the output (last) dimension into three equal chunks.
                attn_weight = hf_weights[f'{prefix}{i}.attn.c_attn.weight']
                split_size = attn_weight.shape[-1] // 3
                block['attn.q_proj.weight'] = attn_weight[:, :split_size]
                block['attn.k_proj.weight'] = attn_weight[:, split_size:2 * split_size]
                block['attn.v_proj.weight'] = attn_weight[:, 2 * split_size:]

                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.attn.c_proj.weight']

                block['ln2.weight'] = hf_weights[f'{prefix}{i}.ln_2.weight']
                block['ln2.bias'] = hf_weights[f'{prefix}{i}.ln_2.bias']

                block['ff1.weight'] = hf_weights[f'{prefix}{i}.mlp.c_fc.weight']
                block['ff1.bias'] = hf_weights[f'{prefix}{i}.mlp.c_fc.bias']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.mlp.c_proj.weight']
                block['ff2.bias'] = hf_weights[f'{prefix}{i}.mlp.c_proj.bias']

                if config.rotary_dim:
                    if f'{prefix}{i}.attn.rotary_emb.inv_freq' in hf_weights:
                        block['attn.rotary_emb.inv_freq'] = hf_weights[f'{prefix}{i}.attn.rotary_emb.inv_freq']

                block_weights_list.append(block)

        except KeyError as e:
            logger.error(f"Failed to map GPT-2 weights: {str(e)}")
            raise ValueError(f"Missing required weight: {str(e)}")

        return block_weights_list

    def _map_bert_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'bert.encoder.layer.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map BERT weights with optimizations."""
        block_weights_list = []

        try:
            for i in range(config.num_layers):
                block = {}

                block['ln1.weight'] = hf_weights[f'{prefix}{i}.attention.output.LayerNorm.weight']
                block['ln1.bias'] = hf_weights[f'{prefix}{i}.attention.output.LayerNorm.bias']

                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.query.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.key.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.attention.self.value.weight']
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.attention.output.dense.weight']

                block['ln2.weight'] = hf_weights[f'{prefix}{i}.output.LayerNorm.weight']
                block['ln2.bias'] = hf_weights[f'{prefix}{i}.output.LayerNorm.bias']

                block['ff1.weight'] = hf_weights[f'{prefix}{i}.intermediate.dense.weight']
                block['ff1.bias'] = hf_weights[f'{prefix}{i}.intermediate.dense.bias']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.output.dense.weight']
                block['ff2.bias'] = hf_weights[f'{prefix}{i}.output.dense.bias']

                if i == 0 and 'bert.embeddings.position_embeddings.weight' in hf_weights:
                    block['position_embeddings'] = hf_weights['bert.embeddings.position_embeddings.weight']

                block_weights_list.append(block)

        except KeyError as e:
            logger.error(f"Failed to map BERT weights: {str(e)}")
            raise ValueError(f"Missing required weight: {str(e)}")

        return block_weights_list

    def _map_t5_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'encoder.block.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map T5 weights with optimizations."""
        block_weights_list = []

        try:
            for i in range(config.num_layers):
                block = {}

                block['ln1.weight'] = hf_weights[f'{prefix}{i}.layer.0.layer_norm.weight']

                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.q.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.k.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.v.weight']
                block['attn.out_proj.weight'] = hf_weights[f'{prefix}{i}.layer.0.SelfAttention.o.weight']

                block['ln2.weight'] = hf_weights[f'{prefix}{i}.layer.1.layer_norm.weight']

                block['ff1.weight'] = hf_weights[f'{prefix}{i}.layer.1.DenseReluDense.wi.weight']
                block['ff2.weight'] = hf_weights[f'{prefix}{i}.layer.1.DenseReluDense.wo.weight']

                if f'{prefix}{i}.layer.0.SelfAttention.relative_attention_bias' in hf_weights:
                    block['attn.relative_attention_bias'] = hf_weights[
                        f'{prefix}{i}.layer.0.SelfAttention.relative_attention_bias'
                    ]

                block_weights_list.append(block)

        except KeyError as e:
            logger.error(f"Failed to map T5 weights: {str(e)}")
            raise ValueError(f"Missing required weight: {str(e)}")

        return block_weights_list

    def _map_llama_weights(
        self,
        hf_weights: Dict[str, np.ndarray],
        config: ModelConfig,
        prefix: str = 'model.layers.'
    ) -> List[Dict[str, np.ndarray]]:
        """Map LLaMA weights with optimizations."""
        block_weights_list = []

        try:
            for i in range(config.num_layers):
                block = {}

                block['attn.q_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.q_proj.weight']
                block['attn.k_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.k_proj.weight']
                block['attn.v_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.v_proj.weight']
                block['attn.o_proj.weight'] = hf_weights[f'{prefix}{i}.self_attn.o_proj.weight']

                # Only copy the rotary inv_freq buffer when the checkpoint has
                # it, mirroring the GPT-2 mapper, so no None entries end up in
                # the block.
                if config.rotary_dim and f'{prefix}{i}.self_attn.rotary_emb.inv_freq' in hf_weights:
                    block['attn.rotary_emb.inv_freq'] = hf_weights[f'{prefix}{i}.self_attn.rotary_emb.inv_freq']

                block['input_layernorm.weight'] = hf_weights[f'{prefix}{i}.input_layernorm.weight']
                block['post_attention_layernorm.weight'] = hf_weights[f'{prefix}{i}.post_attention_layernorm.weight']

                block['mlp.gate_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.gate_proj.weight']
                block['mlp.up_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.up_proj.weight']
                block['mlp.down_proj.weight'] = hf_weights[f'{prefix}{i}.mlp.down_proj.weight']

                block_weights_list.append(block)

        except KeyError as e:
            logger.error(f"Failed to map LLaMA weights: {str(e)}")
            raise ValueError(f"Missing required weight: {str(e)}")

        return block_weights_list


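# Usage sketch (hf_weights here stands for a HuggingFace state_dict already
# converted to numpy arrays; the variable is illustrative, not defined here):
#
#     mapper = WeightMapper(use_cache=False)
#     blocks = mapper.map_weights(hf_weights, config)   # one dict per layer
#     q_proj = blocks[0]['attn.q_proj.weight']

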
class ConfigParser:
    """Enhanced configuration parser with caching."""

    def __init__(self, use_cache: bool = True):
        self.cache_manager = CacheManager() if use_cache else None

    def parse_config(
        self,
        config: Dict[str, Any],
        model_type: Optional[ModelType] = None
    ) -> ModelConfig:
        """
        Parse model configuration with caching.

        Args:
            config: Configuration dictionary
            model_type: Optional model type override

        Returns:
            ModelConfig instance
        """
        if self.cache_manager:
            cache_key = self.cache_manager._compute_key(config, "config_parsing")
            cached_config = self.cache_manager.get(cache_key)
            if cached_config is not None:
                return cached_config

        if not model_type:
            model_type = self._detect_model_type(config)

        parsed_config = self._parse_by_type(config, model_type)

        if self.cache_manager:
            self.cache_manager.set(
                cache_key,
                parsed_config,
                {'model_type': model_type.value}
            )

        return parsed_config

    def _detect_model_type(self, config: Dict[str, Any]) -> ModelType:
        """Detect model type from config."""
        # Prefer the explicit HuggingFace ``model_type`` field when present.
        declared = config.get('model_type')
        if declared:
            try:
                return ModelType(declared)
            except ValueError:
                pass

        # Heuristic fallback: check architecture-specific keys before generic
        # ones (LLaMA/Mistral configs also contain ``num_hidden_layers``, so
        # that key alone cannot identify BERT).
        if 'n_layer' in config:
            return ModelType.GPT2
        elif 'd_model' in config:
            return ModelType.T5
        elif 'sliding_window' in config:
            return ModelType.MISTRAL
        elif 'multi_query_group_num' in config:
            return ModelType.FALCON
        elif 'num_key_value_heads' in config:
            return ModelType.LLAMA
        elif 'num_hidden_layers' in config:
            return ModelType.BERT
        else:
            raise ValueError("Unable to detect model type from config")

    def _parse_by_type(
        self,
        config: Dict[str, Any],
        model_type: ModelType
    ) -> ModelConfig:
        """Parse config based on model type."""
        if model_type == ModelType.GPT2:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['n_layer'],
                num_heads=config['n_head'],
                hidden_dim=config['n_embd'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('n_positions', 1024),
                intermediate_size=config.get('n_inner', None),
                layer_norm_epsilon=config.get('layer_norm_epsilon', 1e-5)
            )
        elif model_type == ModelType.BERT:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_hidden_layers'],
                num_heads=config['num_attention_heads'],
                hidden_dim=config['hidden_size'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('max_position_embeddings', 512),
                intermediate_size=config.get('intermediate_size', None),
                layer_norm_epsilon=config.get('layer_norm_eps', 1e-12)
            )
        elif model_type == ModelType.T5:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_layers'],
                num_heads=config['num_heads'],
                hidden_dim=config['d_model'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('n_positions', 512),
                intermediate_size=config.get('d_ff', None),
                layer_norm_epsilon=config.get('layer_norm_epsilon', 1e-6)
            )
        elif model_type == ModelType.LLAMA:
            return ModelConfig(
                model_type=model_type,
                num_layers=config['num_hidden_layers'],
                num_heads=config['num_attention_heads'],
                hidden_dim=config['hidden_size'],
                vocab_size=config['vocab_size'],
                max_seq_len=config.get('max_position_embeddings', 2048),
                intermediate_size=config.get('intermediate_size', None),
                layer_norm_epsilon=config.get('rms_norm_eps', 1e-6),
                rotary_dim=config.get('rotary_dim', None)
            )
        else:
            raise ValueError(f"Unsupported model type: {model_type}")


def parse_hf_config(config: Dict[str, Any]) -> Dict[str, Any]:
    """Legacy config parsing interface."""
    warnings.warn(
        "parse_hf_config is deprecated, use ConfigParser class instead",
        DeprecationWarning
    )

    parser = ConfigParser(use_cache=True)
    parsed_config = parser.parse_config(config)

    return {
        'num_layers': parsed_config.num_layers,
        'num_heads': parsed_config.num_heads,
        'hidden_dim': parsed_config.hidden_dim,
        'vocab_size': parsed_config.vocab_size,
        'max_seq_len': parsed_config.max_seq_len
    }

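
if __name__ == "__main__":
    # Minimal smoke-test sketch: exercises only the pieces that need neither
    # the DB-backed cache nor a GPU driver. The config values are illustrative
    # (GPT-2 small sized), not tied to any particular checkpoint.
    demo_config = {
        'n_layer': 12,
        'n_head': 12,
        'n_embd': 768,
        'vocab_size': 50257,
        'n_positions': 1024,
    }
    parsed = ConfigParser(use_cache=False).parse_config(demo_config)
    logger.info(
        "Parsed demo config: %d layers, %d heads, hidden_dim=%d",
        parsed.num_layers, parsed.num_heads, parsed.hidden_dim,
    )

    mask = MaskGenerator(use_cache=False).create_causal_mask(8, dtype=np.float32)
    logger.info("Causal mask shape: %s", mask.shape)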