"""Utilities for loading HuggingFace Medusa head checkpoints and converting
them into TensorRT-LLM weight dictionaries, with optional weight-only
quantization and tensor-parallel splitting."""

from pathlib import Path

import torch

from tensorrt_llm import logger
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split


def get_tllm_linear_weight(weight,
                           prefix,
                           bias=None,
                           use_weight_only=False,
                           plugin_weight_only_quant_type=torch.int8,
                           postfix='weight'):
    """Pack one linear layer's tensors into a TensorRT-LLM weight dict.

    When ``use_weight_only`` is set, the weight is transposed and quantized
    symmetrically along its last axis by the TRT-LLM custom op, and a
    per-channel scale entry is emitted alongside the quantized weight.
    """
    results = {}
    if use_weight_only:
        # The quantization op expects a transposed, contiguous CPU tensor.
        v = weight.t().contiguous().cpu()
        processed_torch_weights, torch_weight_scales = \
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                v, plugin_weight_only_quant_type)
        results[prefix + postfix] = processed_torch_weights
        results[prefix + 'per_channel_scale'] = torch_weight_scales
    else:
        results[prefix + postfix] = weight.contiguous()

    if bias is not None:
        results[prefix + 'bias'] = bias

    return results
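
# A minimal usage sketch for get_tllm_linear_weight (illustrative only; the
# shape, dtype, and the 'transformer.layers.0.mlp.fc.' prefix are made-up
# examples, not names taken from this converter):
#
#   w = torch.randn(16, 32, dtype=torch.float16)
#   entries = get_tllm_linear_weight(w,
#                                    'transformer.layers.0.mlp.fc.',
#                                    use_weight_only=True)
#   # entries: {'transformer.layers.0.mlp.fc.weight': <quantized int8 tensor>,
#   #           'transformer.layers.0.mlp.fc.per_channel_scale': <scales>}
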

def load_medusa_hf(medusa_path: str,
                   num_medusa_heads: int,
                   num_medusa_layers: int,
                   mapping=Mapping(),
                   dtype='float32',
                   use_weight_only=False,
                   plugin_weight_only_quant_type=None):
    """Load Medusa heads' weights from a HuggingFace checkpoint directory.

    Looks for ``medusa_lm_head.pt`` first, falling back to
    ``medusa_lm_head.safetensors``. Each head contributes
    ``num_medusa_layers`` linear layers plus one LM head projection; every
    weight is split across tensor-parallel ranks according to ``mapping``.
    """
    logger.info("Loading Medusa heads' weights ...")
    is_ckpt_safetensors = False

    ckpt_file = Path(medusa_path) / "medusa_lm_head.pt"
    if not ckpt_file.exists():
        ckpt_file = Path(medusa_path) / "medusa_lm_head.safetensors"
        is_ckpt_safetensors = True

    if is_ckpt_safetensors:
        logger.info("Safetensors Found ...")
        from safetensors.torch import load_file
        state_dict = load_file(ckpt_file)
    else:
        state_dict = torch.load(ckpt_file, map_location="cpu")

    torch_dtype = str_dtype_to_torch(dtype)
    weights = {}

    for h in range(num_medusa_heads):
        for l in range(num_medusa_layers):
            # Residual-block linear layer l of head h.
            w = state_dict[f"{h}.{l}.linear.weight"].clone().to(torch_dtype)

            split_v = split(w, mapping.tp_size, mapping.tp_rank)
            weights.update(
                get_tllm_linear_weight(
                    split_v, f'medusa_heads.{h}.medusa_layers.{l}.linear.',
                    None, use_weight_only, plugin_weight_only_quant_type))

            b = state_dict[f"{h}.{l}.linear.bias"].clone().to(torch_dtype)

            weights[f'medusa_heads.{h}.medusa_layers.{l}.linear.bias'] = split(
                b, mapping.tp_size, mapping.tp_rank)

        # The final entry of each head is its LM head projection.
        lm = state_dict[f"{h}.{num_medusa_layers}.weight"].clone().to(
            torch_dtype)

        weights[f'medusa_heads.{h}.lm_head.weight'] = split(
            lm, mapping.tp_size, mapping.tp_rank)

    return weights
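
# A minimal usage sketch for load_medusa_hf (illustrative only; the checkpoint
# path and head/layer counts are hypothetical):
#
#   weights = load_medusa_hf('/path/to/medusa_checkpoint',
#                            num_medusa_heads=4,
#                            num_medusa_layers=1,
#                            mapping=Mapping(world_size=1, rank=0, tp_size=1),
#                            dtype='float16')
#   # weights maps keys such as
#   # 'medusa_heads.{h}.medusa_layers.{l}.linear.weight' and
#   # 'medusa_heads.{h}.lm_head.weight' to tensors ready for a
#   # TensorRT-LLM Medusa checkpoint.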