aspctu's picture
Upload folder using huggingface_hub
5000658 verified
from pathlib import Path
import torch
from tensorrt_llm import logger
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split
def get_tllm_linear_weight(weight,
                           prefix,
                           bias=None,
                           use_weight_only=False,
                           plugin_weight_only_quant_type=torch.int8,
                           postfix='weight'):
    """Build the named tensor entries for one linear layer.

    Args:
        weight: Weight tensor of the layer.
        prefix: String prepended to every key in the result (callers pass
            it with a trailing dot, e.g. ``'...linear.'``).
        bias: Optional bias tensor; stored unchanged under
            ``prefix + 'bias'`` when it is not ``None``.
        use_weight_only: When true, the weight is transposed, made
            contiguous, moved to CPU and passed through
            ``torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix``.
        plugin_weight_only_quant_type: Quantization type forwarded to the
            quantization op; only used when ``use_weight_only`` is true.
        postfix: Suffix of the key under which the weight is stored.

    Returns:
        A dict mapping ``prefix + postfix`` to the (possibly quantized)
        weight, plus ``prefix + 'per_channel_scale'`` in the weight-only
        case and ``prefix + 'bias'`` when a bias was given.
    """
    weight_key = prefix + postfix

    if not use_weight_only:
        # Plain path: keep the values, only guarantee a contiguous layout.
        tensors = {weight_key: weight.contiguous()}
    else:
        # The quantization op receives the transposed weight on CPU.
        transposed = weight.t().contiguous().cpu()
        quantize = \
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix
        quantized, scales = quantize(transposed,
                                     plugin_weight_only_quant_type)
        tensors = {
            weight_key: quantized,
            prefix + 'per_channel_scale': scales,
        }

    if bias is not None:
        tensors[prefix + 'bias'] = bias

    return tensors
def load_medusa_hf(medusa_path: str,
                   num_medusa_heads: int,
                   num_medusa_layers: int,
                   mapping=Mapping(),
                   dtype='float32',
                   use_weight_only=False,
                   plugin_weight_only_quant_type=None):
    """Load Medusa head weights from a checkpoint directory.

    The directory is searched for ``medusa_lm_head.pt`` first and for
    ``medusa_lm_head.safetensors`` second. Every tensor that is read is
    cloned, cast to ``dtype`` and passed through ``split`` with the
    tensor-parallel size and rank taken from ``mapping``.

    Args:
        medusa_path: Directory containing the Medusa checkpoint file.
        num_medusa_heads: Number of Medusa heads to read.
        num_medusa_layers: Number of linear layers to read per head. The
            checkpoint entry at index ``num_medusa_layers`` of each head is
            read as that head's LM head.
        mapping: Object providing ``tp_size`` and ``tp_rank``. The default
            instance is created once at definition time; this function
            only reads from it.
        dtype: Name of the target dtype, converted with
            ``str_dtype_to_torch``.
        use_weight_only: Forwarded to ``get_tllm_linear_weight`` for the
            per-layer linear weights. Biases and LM head weights are stored
            without going through that helper.
        plugin_weight_only_quant_type: Forwarded unchanged to
            ``get_tllm_linear_weight``.
            NOTE(review): ``None`` is passed through as-is, so callers
            presumably must supply a type when ``use_weight_only`` is
            true -- confirm against the quantization op.

    Returns:
        Dict of tensors keyed by
        ``medusa_heads.{h}.medusa_layers.{l}.linear.*`` and
        ``medusa_heads.{h}.lm_head.weight``.

    Raises:
        FileNotFoundError: If neither checkpoint file exists in
            ``medusa_path``.
        KeyError: If an expected entry is missing from the loaded state
            dict.
    """
    logger.info("Loading Medusa heads' weights ...")

    ckpt_dir = Path(medusa_path)
    pt_file = ckpt_dir / "medusa_lm_head.pt"
    safetensors_file = ckpt_dir / "medusa_lm_head.safetensors"

    if pt_file.exists():
        # NOTE(security): torch.load deserializes with pickle, which can
        # execute arbitrary code from a malicious file. Only point this at
        # checkpoints from a trusted source (or consider
        # ``weights_only=True`` if the file is a plain tensor dict).
        state_dict = torch.load(pt_file, map_location="cpu")
    elif safetensors_file.exists():
        logger.info("Safetensors Found ...")
        # Imported lazily so safetensors is only needed for this format.
        from safetensors.torch import load_file
        state_dict = load_file(safetensors_file)
    else:
        # Fail early with both candidate paths instead of reporting
        # "Safetensors Found" and erroring inside the loader.
        raise FileNotFoundError(
            f"No Medusa checkpoint found: expected '{pt_file}' or "
            f"'{safetensors_file}'.")

    torch_dtype = str_dtype_to_torch(dtype)
    weights = {}

    def _load_shard(key):
        # Clone + cast the checkpoint tensor, then take this rank's part.
        tensor = state_dict[key].clone().to(torch_dtype)
        return split(tensor, mapping.tp_size, mapping.tp_rank)

    for head_idx in range(num_medusa_heads):
        head_prefix = f'medusa_heads.{head_idx}.'

        for layer_idx in range(num_medusa_layers):
            src_prefix = f"{head_idx}.{layer_idx}.linear."
            dst_prefix = f'{head_prefix}medusa_layers.{layer_idx}.linear.'

            weights.update(
                get_tllm_linear_weight(_load_shard(src_prefix + 'weight'),
                                       dst_prefix, None, use_weight_only,
                                       plugin_weight_only_quant_type))

            # The bias is stored directly (not via the helper above).
            weights[dst_prefix + 'bias'] = _load_shard(src_prefix + 'bias')

        # LM head: stored in the checkpoint right after the head's layers.
        weights[head_prefix + 'lm_head.weight'] = _load_shard(
            f"{head_idx}.{num_medusa_layers}.weight")

    return weights