import time
from pathlib import Path
from typing import Optional

import jax
import torch
from jax import dlpack as jax_dlpack
from torch.utils import dlpack as torch_dlpack

from ..._utils import pad_vocab_size, release_gc
from ...layers import MoeConfig
from ...logger import logger
from ...quantization import QuantAlgo
from ..modeling_utils import PretrainedConfig, QuantConfig, optimize_model


def split(v, tp_size, idx, dim=0):
    if tp_size == 1:
        return v
    if len(v.shape) == 1:
        return torch.chunk(v, tp_size)[idx].contiguous()
    else:
        return torch.chunk(v, tp_size, dim=dim)[idx]
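
# Hedged, illustrative example of `split` (hypothetical tensor, not from the
# checkpoint): splitting a (4, 8) matrix across tp_size=2 ranks along dim=1
# gives each rank a contiguous (4, 4) shard.
#
#   w = torch.arange(32.).reshape(4, 8)
#   split(w, tp_size=2, idx=0, dim=1)  # columns 0..3, shape (4, 4)
#   split(w, tp_size=2, idx=1, dim=1)  # columns 4..7, shape (4, 4)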


def split_matrix_tp(v, tensor_parallel, rank, dim):
    return split(v, tensor_parallel, rank, dim=dim)


def get_weight(params, prefix, dtype, postfix='.weight'):
    if params[prefix + postfix].dtype != dtype:
        params[prefix + postfix].data = params[prefix + postfix].to(dtype)
    return params[prefix + postfix].detach().cpu()


def get_jax_weight(params, prefix, dtype, postfix='.weight', key_name='scale'):
    return torch.as_tensor((params[prefix + postfix][key_name])._value,
                           dtype=dtype).T


def get_jax_weight_scale(params, key):
    jax_obj = params[key]['w']
    jax_scales = jax.device_put(jax_obj.scales, device=jax.devices('cpu')[0])
    # Hand the quantization scales from JAX to torch via DLPack (no host copy).
    torch_scales = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(jax_scales))
    return torch.as_tensor(jax_obj.weight._value,
                           dtype=torch.int8), torch_scales
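
# Minimal standalone sketch of the same DLPack bridge used above (illustrative
# only; `x` is a hypothetical CPU array, not a checkpoint tensor):
#
#   import jax.numpy as jnp
#   x = jnp.ones((4,), dtype=jnp.float32)
#   t = torch_dlpack.from_dlpack(jax_dlpack.to_dlpack(x))  # shares the buffer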


def get_tllm_linear_weight(
    weight,
    torch_weight_scales,
    prefix,
    plugin_weight_only_quant_type=torch.int8,
    postfix='weight',
):
    results = {}
    # Repack the int8 weights into the layout expected by the weight-only
    # GEMM plugin.
    processed_weight = torch.ops.trtllm.preprocess_weights_for_mixed_gemm(
        weight.contiguous(), plugin_weight_only_quant_type, torch.bfloat16)
    results[prefix + postfix] = processed_weight
    results[prefix + 'per_channel_scale'] = torch_weight_scales.contiguous()
    return results
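
# For example, calling get_tllm_linear_weight with
# prefix='transformer.layers.0.attention.qkv.' returns a dict with keys
# 'transformer.layers.0.attention.qkv.weight' (repacked int8 weights) and
# 'transformer.layers.0.attention.qkv.per_channel_scale' (the scales);
# exact shapes depend on the checkpoint.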


def convert_grok(hf_model,
                 config,
                 mapping,
                 vocab_size=32000,
                 dtype='float32',
                 use_parallel_embedding=False,
                 sharding_dim=0,
                 use_weight_only=False,
                 share_embedding_table=False,
                 use_gemm_woq_plugin=False,
                 plugin_weight_only_quant_type=torch.int8,
                 moe_config=None):

    weights = {}
    tik = time.time()
    tensor_parallel = mapping.tp_size
    model_params = hf_model
    dtype = getattr(torch, dtype)

    layers_range = mapping.pp_layers(config['num_hidden_layers'])

    def convert_layer(l):
        prefix = f'transformer/decoder_layer_{l}/'
        logger.debug(prefix)
        tllm_prex = f'transformer.layers.{l - layers_range[0]}.'

        q_weight, q_scale = get_jax_weight_scale(
            model_params, prefix + 'multi_head_attention/query')
        k_weight, k_scale = get_jax_weight_scale(
            model_params, prefix + 'multi_head_attention/key')
        v_weight, v_scale = get_jax_weight_scale(
            model_params, prefix + 'multi_head_attention/value')

        # Column-split Q/K/V (and their scales) across TP ranks, then fuse
        # them into a single QKV weight.
        wq = split(q_weight, mapping.tp_size, mapping.tp_rank, dim=1)
        wk = split(k_weight, mapping.tp_size, mapping.tp_rank, dim=1)
        wv = split(v_weight, mapping.tp_size, mapping.tp_rank, dim=1)
        qs = split(q_scale, mapping.tp_size, mapping.tp_rank, dim=1)
        ks = split(k_scale, mapping.tp_size, mapping.tp_rank, dim=1)
        vs = split(v_scale, mapping.tp_size, mapping.tp_rank, dim=1)
        split_v = torch.concat((wq, wk, wv), dim=1)
        scale_v = torch.concat((qs, ks, vs), dim=1)

        weights.update(
            get_tllm_linear_weight(split_v, scale_v.squeeze(),
                                   tllm_prex + 'attention.qkv.',
                                   plugin_weight_only_quant_type))

        attn_dense_weight, attn_dense_scales = get_jax_weight_scale(
            model_params, prefix + 'multi_head_attention/linear')

        # The attention output projection is row-split (dim=0) across TP ranks.
        split_v = split_matrix_tp(attn_dense_weight,
                                  tensor_parallel,
                                  mapping.tp_rank,
                                  dim=0)
        split_scales = split_matrix_tp(attn_dense_scales,
                                       tensor_parallel,
                                       mapping.tp_rank,
                                       dim=0)

        weights.update(
            get_tllm_linear_weight(split_v, split_scales.squeeze(),
                                   tllm_prex + 'attention.dense.',
                                   plugin_weight_only_quant_type))

        # MoE expert weights: first shard experts across expert-parallel
        # ranks (dim=0), then tensor-parallel shard within each expert.
        w3, s3 = get_jax_weight_scale(model_params, prefix + 'moe/linear_v')
        w2, s2 = get_jax_weight_scale(model_params, prefix + 'moe/linear_1')
        w1, s1 = get_jax_weight_scale(model_params, prefix + 'moe/linear')

        w3_split = split(w3, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)
        w2_split = split(w2, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)
        w1_split = split(w1, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)

        s3_split = split(s3, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)
        s2_split = split(s2, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)
        s1_split = split(s1, mapping.moe_ep_size, mapping.moe_ep_rank, dim=0)

        w3_split = split(w3_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=2)
        w2_split = split(w2_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=1)
        w1_split = split(w1_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=2)

        s3_split = split(s3_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=2)
        s2_split = split(s2_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=1)
        s1_split = split(s1_split,
                         mapping.moe_tp_size,
                         mapping.moe_tp_rank,
                         dim=2)

        weights.update(
            get_tllm_linear_weight(w2_split,
                                   s2_split.reshape(moe_config.num_experts,
                                                    -1),
                                   tllm_prex + 'mlp.proj.',
                                   plugin_weight_only_quant_type))

        # w3 (linear_v) and w1 (linear) are fused into a single fc weight.
        weights.update(
            get_tllm_linear_weight(
                torch.concat([w3_split, w1_split], dim=-1),
                torch.concat([s3_split, s1_split],
                             dim=-1).reshape(moe_config.num_experts, -1),
                tllm_prex + 'mlp.fc.',
                plugin_weight_only_quant_type,
            ))

        moe_experts_gate_weights = get_jax_weight(model_params,
                                                  prefix + 'router',
                                                  torch.float32,
                                                  postfix='',
                                                  key_name='w').contiguous()
        weights[tllm_prex + 'mlp.router.weight'] = moe_experts_gate_weights

        input_ln_weight = get_jax_weight(model_params,
                                         prefix + 'rms_norm',
                                         dtype,
                                         postfix='')
        weights[tllm_prex + 'input_layernorm.weight'] = input_ln_weight

        post_attn_weight = get_jax_weight(model_params,
                                          prefix + 'rms_norm_1',
                                          dtype,
                                          postfix='')
        weights[tllm_prex + 'post_attn_layernorm.weight'] = post_attn_weight

        post_ln_weight = get_jax_weight(model_params,
                                        prefix + 'rms_norm_2',
                                        dtype,
                                        postfix='')
        weights[tllm_prex + 'post_layernorm.weight'] = post_ln_weight

        post_mlp_weight = get_jax_weight(model_params,
                                         prefix + 'rms_norm_3',
                                         dtype,
                                         postfix='')
        weights[tllm_prex + 'post_mlp_layernorm.weight'] = post_mlp_weight

    for l in layers_range:
        convert_layer(l)
        release_gc()

    v = get_jax_weight(model_params,
                       'language_model/in_out_embed',
                       dtype,
                       postfix='',
                       key_name='embeddings').T
    tie_word_embeddings = config['tie_word_embeddings']
    if tie_word_embeddings:
        if mapping.is_last_pp_rank():
            if vocab_size % mapping.tp_size != 0:
                # Pad the vocabulary so it divides evenly across TP ranks.
                vocab_size_padded = pad_vocab_size(vocab_size, mapping.tp_size)
                pad_width = vocab_size_padded - vocab_size
                v = torch.nn.functional.pad(v, (0, pad_width, 0, 0), 'constant',
                                            0)
            weights['lm_head.weight'] = split(v, mapping.tp_size,
                                              mapping.tp_rank)

    if use_parallel_embedding:
        v = split_matrix_tp(v,
                            mapping.tp_size,
                            mapping.tp_rank,
                            dim=sharding_dim)

    if mapping.is_first_pp_rank():
        weights['transformer.vocab_embedding.weight'] = v

    ln_f_w = get_jax_weight(model_params,
                            'language_model/rms_norm',
                            dtype,
                            postfix='')
    weights['transformer.ln_f.weight'] = ln_f_w

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    logger.info(f'Weights loaded. Total time: {t}')
    return weights
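
# Hedged usage sketch for convert_grok (`params` and `cfg` are hypothetical;
# real calls go through load_weights_from_xai below):
#
#   moe = MoeConfig(8, 2, MoeConfig.ExpertScaleNormalizationMode.NONE)
#   weights = convert_grok(params, cfg, mapping,
#                          vocab_size=cfg['vocab_size'],
#                          dtype='bfloat16',
#                          plugin_weight_only_quant_type=torch.int8,
#                          moe_config=moe)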


def create_config_from_xai(dtype,
                           mapping,
                           quantization: Optional[QuantConfig] = None,
                           override_fields: dict = {}):
    config = {}
    # Grok-1 hyperparameters, hard-coded to match the xAI release.
    hf_config = {
        "architectures": ["Grok1ModelForCausalLM"],
        "vocab_size": 131072,
        "hidden_size": 6144,
        "intermediate_size": 32768,
        "num_hidden_layers": 64,
        "num_attention_heads": 48,
        "num_key_value_heads": 8,
        "attn_output_multiplier": 0.08838834764831845,
        "embedding_multiplier_scale": 78.38367176906169,
        "output_multiplier_scale": 0.5773502691896257,
        "max_attn_value": 30.0,
        "max_position_embeddings": 8192,
        "rms_norm_eps": 1e-5,
        "use_cache": True,
        "pad_token_id": 0,
        "bos_token_id": 1,
        "eos_token_id": 2,
        "tie_word_embeddings": True,
        "num_experts_per_tok": 2,
        "num_experts": 8,
        "output_router_logits": False,
        "router_aux_loss_coef": 0.001,
        "torch_dtype": "bfloat16",
        "transformers_version": "4.35.0"
    }

    n_head = hf_config['num_attention_heads']
    inter_size = hf_config['intermediate_size']
    n_layer = hf_config['num_hidden_layers']
    n_embd = hf_config['hidden_size']
    n_kv_head = hf_config['num_key_value_heads']
    rms_norm_eps = hf_config['rms_norm_eps']
    vocab_size = hf_config['vocab_size']
    n_positions = hf_config['max_position_embeddings']
    hidden_act = 'geglu'
    config['rotary_scaling'] = None
    rotary_base = 10000.0

    moe_num_experts = hf_config['num_experts']
    moe_top_k = hf_config['num_experts_per_tok']

    attn_output_multiplier = hf_config['attn_output_multiplier']
    embedding_multiplier_scale = hf_config['embedding_multiplier_scale']
    output_multiplier_scale = hf_config['output_multiplier_scale']
    max_attn_value = hf_config['max_attn_value']

    architecture = hf_config['architectures'][0]
    attn_bias = False

    config.update({
        'architecture': architecture,
        'dtype': dtype,
        'logits_dtype': 'float32',
        'num_hidden_layers': n_layer,
        'num_attention_heads': n_head,
        'hidden_size': n_embd,
        'intermediate_size': inter_size,
        'num_key_value_heads': n_kv_head,
        'vocab_size': vocab_size,
        'position_embedding_type': 'rope_gpt_neox',
        'max_position_embeddings': n_positions,
        'hidden_act': hidden_act,
        'rotary_base': rotary_base,
        'norm_epsilon': rms_norm_eps,
        'moe_num_experts': moe_num_experts,
        'moe_top_k': moe_top_k,
        'moe_normalization_mode': MoeConfig.ExpertScaleNormalizationMode.NONE,
        'mapping': {
            'world_size': mapping.tp_size * mapping.pp_size,
            'tp_size': mapping.tp_size,
            'pp_size': mapping.pp_size,
            'moe_tp_size': mapping.moe_tp_size,
            'moe_ep_size': mapping.moe_ep_size,
        },
        'attn_bias': attn_bias,
        'attn_output_multiplier': attn_output_multiplier,
        'embedding_multiplier_scale': embedding_multiplier_scale,
        'output_multiplier_scale': output_multiplier_scale,
        'max_attn_value': max_attn_value,
        'tie_word_embeddings': True,
    })

    assert quantization is not None, \
        'create_config_from_xai expects a QuantConfig'
    config['quantization'] = quantization.to_dict()
    config.update(override_fields)

    return config
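
# Example (illustrative): build a config with an overridden logits dtype.
# `quant_config` is a hypothetical QuantConfig for W8A16 weight-only quant.
#
#   cfg = create_config_from_xai('bfloat16', mapping, quant_config,
#                                override_fields={'logits_dtype': 'float16'})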


def from_hugging_face(cls,
                      model_dir,
                      dtype,
                      *,
                      mapping,
                      quantization: Optional[QuantConfig] = None,
                      override_fields={},
                      skip_loading_weights=False,
                      preloaded_model=None):
    '''Create a GrokForCausalLM object from the given parameters.'''
    assert model_dir is not None
    if isinstance(model_dir, Path):
        model_dir = str(model_dir)

    if override_fields.get('share_embedding_table', False):
        logger.warning(
            "Grok model does not support share_embedding_table; setting "
            "share_embedding_table=False")
        override_fields['share_embedding_table'] = False

    config = create_config_from_xai(dtype,
                                    mapping,
                                    quantization,
                                    override_fields=override_fields)

    pretrained_config = PretrainedConfig.from_dict(config)
    pretrained_config.set_rank(mapping.rank)

    grok = cls.from_config(pretrained_config)
    grok = optimize_model(
        grok,
        use_parallel_embedding=pretrained_config.use_parallel_embedding,
        share_embedding_table=pretrained_config.share_embedding_table,
    )

    if skip_loading_weights:
        return grok

    weights = load_weights_from_xai(config=config,
                                    mapping=mapping,
                                    model=preloaded_model)
    grok.load(weights)
    return grok


def quantize(dtype,
             model_dir,
             output_dir,
             mapping,
             quantization: QuantConfig,
             *,
             override_fields,
             dataset_cache_dir: Optional[str] = None):
    '''Quantize and save the model as a TRT-LLM checkpoint to output_dir.'''
    # Not implemented for Grok.
    pass


def load_weights_from_xai(*, config, mapping, model):
    assert model is not None
    quant_algo = config['quantization']['quant_algo']
    # Only INT8 weight-only quantization (W8A16) is supported for Grok.
    assert quant_algo == QuantAlgo.W8A16
    plugin_weight_only_quant_type = torch.int8

    moe_config = MoeConfig(config['moe_num_experts'], config['moe_top_k'],
                           config['moe_normalization_mode']).validate()

    use_weight_only = quant_algo in [QuantAlgo.W8A16]

    weights = convert_grok(
        model,
        config,
        mapping,
        vocab_size=config['vocab_size'],
        dtype=config['dtype'],
        use_weight_only=use_weight_only,
        use_gemm_woq_plugin=not config.get('disable_weight_only_quant_plugin',
                                           False),
        plugin_weight_only_quant_type=plugin_weight_only_quant_type,
        use_parallel_embedding=config.get('use_parallel_embedding', False),
        sharding_dim=config.get('embedding_sharding_dim', 0),
        share_embedding_table=config.get('share_embedding_table', False),
        moe_config=moe_config)
    return weights
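
# Hedged end-to-end sketch (GrokForCausalLM and `params` are assumptions about
# the caller; this module only defines the conversion helpers):
#
#   grok = from_hugging_face(
#       GrokForCausalLM, './grok-1', 'bfloat16',
#       mapping=mapping,
#       quantization=QuantConfig(quant_algo=QuantAlgo.W8A16),
#       preloaded_model=params)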