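"""Convert a FireRedASR AED decoder checkpoint into a TensorRT-LLM checkpoint.

Loads the original ``model.pth.tar`` package, remaps the FireRedASR parameter
names to a Whisper-style layout, converts the decoder weights to the TRT-LLM
``DecoderModel`` naming scheme, and writes ``config.json`` plus
``rank0.safetensors`` under ``<output_dir>/decoder``.

Example invocation (file and path names are illustrative):

    python convert_firered_decoder.py \
        --model_path /path/to/model.pth.tar \
        --output_dir tllm_checkpoint \
        --dtype float16 --use_weight_only --weight_only_precision int8
"""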
import argparse
import json
import os
import time

import torch
from safetensors.torch import save_file

import tensorrt_llm
from tensorrt_llm.functional import LayerNormPositionType, LayerNormType
from tensorrt_llm.models.convert_utils import weight_only_quantize_dict
from tensorrt_llm.quantization import QuantAlgo


def parse_arguments():
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, required=True,
                        help="Path to the FireRedASR model.pth.tar checkpoint.")
    parser.add_argument('--output_dir', type=str, default='tllm_checkpoint',
                        help='The path to save the TensorRT-LLM checkpoint.')
    parser.add_argument('--dtype', type=str, default='float16',
                        choices=['float32', 'bfloat16', 'float16'])
    parser.add_argument('--logits_dtype', type=str, default='float16',
                        choices=['float16', 'float32'])
    parser.add_argument(
        '--use_weight_only',
        default=False,
        action="store_true",
        help='Quantize weights for the various GEMMs to INT4/INT8. '
        'See --weight_only_precision to set the precision.')
    parser.add_argument(
        '--weight_only_precision',
        const='int8',
        type=str,
        nargs='?',
        default='int8',
        choices=['int8', 'int4'],
        help='Define the precision for the weights when using weight-only '
        'quantization. This argument only has an effect together with '
        '--use_weight_only.')
    return parser.parse_args()


def get_decoder_config(model_args, dtype: str, logits_dtype: str,
                       quant_algo: QuantAlgo = None) -> dict:
    """Build the TensorRT-LLM DecoderModel config from the FireRedASR args."""
    return {
        'architecture': "DecoderModel",
        'dtype': dtype,
        'logits_dtype': logits_dtype,
        'num_hidden_layers': model_args.n_layers_dec,
        'num_attention_heads': model_args.n_head,
        'hidden_size': model_args.d_model,
        'norm_epsilon': 1e-5,
        'vocab_size': model_args.odim,
        'hidden_act': "gelu",
        'use_parallel_embedding': False,
        'embedding_sharding_dim': 0,
        'max_position_embeddings': model_args.pe_maxlen,
        'use_prompt_tuning': False,
        'head_size': model_args.d_model // model_args.n_head,
        'has_position_embedding': True,
        'layernorm_type': LayerNormType.LayerNorm,
        'has_attention_qkvo_bias': True,
        'has_mlp_bias': True,
        'has_model_final_layernorm': True,
        'has_embedding_layernorm': False,
        'has_embedding_scale': True,
        'ffn_hidden_size': 4 * model_args.d_model,
        'q_scaling': 1.0,
        'layernorm_position': LayerNormPositionType.pre_layernorm,
        'relative_attention': False,
        'max_distance': 0,
        'num_buckets': 0,
        'model_type': 'whisper',
        'rescale_before_lm_head': False,
        'encoder_hidden_size': model_args.d_model,
        'encoder_num_heads': model_args.n_head,
        'encoder_head_size': None,
        'skip_cross_kv': False,
        'quantization': {
            'quant_algo': quant_algo
        },
    }


def remap_state_dict(original_state_dict):
    """Rename FireRedASR decoder parameters to a Whisper-style layout."""
    new_state_dict = {}
    for key, value in original_state_dict.items():
        if key.startswith("decoder."):
            new_key = key

            # Top-level modules: token embedding, layer stack, final norm,
            # and output projection.
            new_key = new_key.replace("decoder.tgt_word_emb.", "decoder.token_embedding.")
            new_key = new_key.replace("decoder.layer_stack.", "decoder.blocks.")
            new_key = new_key.replace("decoder.layer_norm_out.", "decoder.ln.")
            new_key = new_key.replace("decoder.tgt_word_prj.", "decoder.output_projection.")

            # Per-block layer norms and attention modules. The cross-attention
            # module name already matches the target layout, so only the norms
            # and the self-attention module need renaming.
            new_key = new_key.replace(".self_attn_norm.", ".attn_ln.")
            new_key = new_key.replace(".self_attn.", ".attn.")
            new_key = new_key.replace(".cross_attn_norm.", ".cross_attn_ln.")
            new_key = new_key.replace(".mlp_norm.", ".mlp_ln.")

            # MLP: w_1/w_2 become indices 0/2 in a Sequential-style naming.
            new_key = new_key.replace(".mlp.w_1.", ".mlp.0.")
            new_key = new_key.replace(".mlp.w_2.", ".mlp.2.")

            # Attention projections.
            new_key = new_key.replace(".w_qs.", ".query.")
            new_key = new_key.replace(".w_ks.", ".key.")
            new_key = new_key.replace(".w_vs.", ".value.")
            new_key = new_key.replace(".fc.", ".out.")

            new_state_dict[new_key] = value

    # The positional-encoding table is stored with a leading batch dimension
    # of 1; squeeze it away before reuse as a plain embedding table.
    if "decoder.positional_encoding.pe" in original_state_dict:
        new_state_dict["decoder.positional_embedding"] = \
            original_state_dict["decoder.positional_encoding.pe"].squeeze(0)

    return new_state_dict


def convert_firered_decoder(model_args, model_params, quant_algo: QuantAlgo = None):
    """Map the remapped FireRedASR decoder weights onto TRT-LLM tensor names."""
    weights = {}

    # Embeddings and output projection.
    weights['transformer.vocab_embedding.weight'] = model_params['decoder.token_embedding.weight']
    weights['lm_head.weight'] = model_params['decoder.output_projection.weight']
    weights['transformer.position_embedding.weight'] = model_params['decoder.positional_embedding']

    for i in range(model_args.n_layers_dec):
        trtllm_layer_name_prefix = f'transformer.layers.{i}'

        # Self-attention: TRT-LLM expects a single fused QKV tensor, so the
        # separate Q/K/V projections are concatenated along the output dim.
        q_w = model_params[f'decoder.blocks.{i}.attn.query.weight']
        k_w = model_params[f'decoder.blocks.{i}.attn.key.weight']
        v_w = model_params[f'decoder.blocks.{i}.attn.value.weight']
        weights[f'{trtllm_layer_name_prefix}.self_attention.qkv.weight'] = torch.cat([q_w, k_w, v_w], dim=0)

        q_b = model_params[f'decoder.blocks.{i}.attn.query.bias']
        # The source checkpoint stores no key-projection bias; substitute
        # zeros so the fused QKV bias keeps the expected shape.
        k_b = torch.zeros_like(q_b)
        v_b = model_params[f'decoder.blocks.{i}.attn.value.bias']
        weights[f'{trtllm_layer_name_prefix}.self_attention.qkv.bias'] = torch.cat([q_b, k_b, v_b], dim=0)

        weights[f'{trtllm_layer_name_prefix}.self_attention.dense.weight'] = model_params[f'decoder.blocks.{i}.attn.out.weight']
        weights[f'{trtllm_layer_name_prefix}.self_attention.dense.bias'] = model_params[f'decoder.blocks.{i}.attn.out.bias']
        weights[f'{trtllm_layer_name_prefix}.self_attention_layernorm.weight'] = model_params[f'decoder.blocks.{i}.attn_ln.weight']
        weights[f'{trtllm_layer_name_prefix}.self_attention_layernorm.bias'] = model_params[f'decoder.blocks.{i}.attn_ln.bias']

        # Cross-attention: same fused QKV treatment as self-attention.
        q_w = model_params[f'decoder.blocks.{i}.cross_attn.query.weight']
        k_w = model_params[f'decoder.blocks.{i}.cross_attn.key.weight']
        v_w = model_params[f'decoder.blocks.{i}.cross_attn.value.weight']
        weights[f'{trtllm_layer_name_prefix}.cross_attention.qkv.weight'] = torch.cat([q_w, k_w, v_w], dim=0)

        q_b = model_params[f'decoder.blocks.{i}.cross_attn.query.bias']
        k_b = torch.zeros_like(q_b)
        v_b = model_params[f'decoder.blocks.{i}.cross_attn.value.bias']
        weights[f'{trtllm_layer_name_prefix}.cross_attention.qkv.bias'] = torch.cat([q_b, k_b, v_b], dim=0)

        weights[f'{trtllm_layer_name_prefix}.cross_attention.dense.weight'] = model_params[f'decoder.blocks.{i}.cross_attn.out.weight']
        weights[f'{trtllm_layer_name_prefix}.cross_attention.dense.bias'] = model_params[f'decoder.blocks.{i}.cross_attn.out.bias']
        weights[f'{trtllm_layer_name_prefix}.cross_attention_layernorm.weight'] = model_params[f'decoder.blocks.{i}.cross_attn_ln.weight']
        weights[f'{trtllm_layer_name_prefix}.cross_attention_layernorm.bias'] = model_params[f'decoder.blocks.{i}.cross_attn_ln.bias']

        # Feed-forward network and its layer norm.
        weights[f'{trtllm_layer_name_prefix}.mlp.fc.weight'] = model_params[f'decoder.blocks.{i}.mlp.0.weight']
        weights[f'{trtllm_layer_name_prefix}.mlp.fc.bias'] = model_params[f'decoder.blocks.{i}.mlp.0.bias']
        weights[f'{trtllm_layer_name_prefix}.mlp.proj.weight'] = model_params[f'decoder.blocks.{i}.mlp.2.weight']
        weights[f'{trtllm_layer_name_prefix}.mlp.proj.bias'] = model_params[f'decoder.blocks.{i}.mlp.2.bias']
        weights[f'{trtllm_layer_name_prefix}.mlp_layernorm.weight'] = model_params[f'decoder.blocks.{i}.mlp_ln.weight']
        weights[f'{trtllm_layer_name_prefix}.mlp_layernorm.bias'] = model_params[f'decoder.blocks.{i}.mlp_ln.bias']

    # Final layer norm.
    weights['transformer.ln_f.weight'] = model_params['decoder.ln.weight']
    weights['transformer.ln_f.bias'] = model_params['decoder.ln.bias']

    if quant_algo is not None:
        return weight_only_quantize_dict(weights, quant_algo=quant_algo)
    return weights
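

# Optional post-conversion check (an illustrative sketch, not part of the
# original flow): verify that every converted floating-point tensor is finite,
# which catches dtype-overflow issues when casting down to float16. If used,
# call e.g. `check_weights_finite(decoder_weights)` right before `save_file`
# in the main block below.
def check_weights_finite(weights: dict) -> None:
    for name, tensor in weights.items():
        if tensor.is_floating_point() and not torch.isfinite(tensor).all():
            raise ValueError(f"Non-finite values in converted weight '{name}'")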


if __name__ == '__main__':
    print(f"Using TensorRT-LLM version: {tensorrt_llm.__version__}")
    args = parse_arguments()
    tik = time.time()

    os.makedirs(args.output_dir, exist_ok=True)

    # Resolve the weight-only quantization algorithm, if requested.
    quant_algo = None
    if args.use_weight_only and args.weight_only_precision == 'int8':
        quant_algo = QuantAlgo.W8A16
    elif args.use_weight_only and args.weight_only_precision == 'int4':
        quant_algo = QuantAlgo.W4A16

    # Load the original checkpoint package, which bundles the model args and
    # the state dict. `weights_only=False` is required because the args object
    # is pickled alongside the tensors.
    package = torch.load(args.model_path, map_location='cpu', weights_only=False)
    model_args = package["args"]
    original_state_dict = package["model_state_dict"]
    print(f"Successfully loaded checkpoint from {args.model_path}")
    print("Original model args:", model_args)

    # Rename FireRedASR parameters to the intermediate Whisper-style layout.
    remapped_state_dict = remap_state_dict(original_state_dict)

    # Cast all tensors to the target dtype.
    tensor_dtype = getattr(torch, args.dtype)
    for key, value in remapped_state_dict.items():
        remapped_state_dict[key] = value.to(tensor_dtype)

    print("Converting decoder checkpoint...")
    decoder_config = get_decoder_config(model_args, args.dtype, args.logits_dtype, quant_algo)
    decoder_weights = convert_firered_decoder(model_args, remapped_state_dict, quant_algo)

    # Write the TRT-LLM checkpoint: config.json plus the rank-0 weights.
    decoder_save_dir = os.path.join(args.output_dir, "decoder")
    os.makedirs(decoder_save_dir, exist_ok=True)

    with open(os.path.join(decoder_save_dir, 'config.json'), 'w') as f:
        json.dump(decoder_config, f, indent=4)

    save_file(decoder_weights, os.path.join(decoder_save_dir, 'rank0.safetensors'))

    tok = time.time()
    t = time.strftime('%H:%M:%S', time.gmtime(tok - tik))
    print(f'Checkpoint successfully converted and saved to {args.output_dir}.')
    print(f'Total time of converting checkpoints: {t}')