from pathlib import Path

import torch

from tensorrt_llm import logger
from tensorrt_llm._utils import str_dtype_to_torch
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models.convert_utils import split
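

# Illustrative sketch (hypothetical helper, not part of TensorRT-LLM): the
# `split` helper imported above shards each Medusa tensor across
# tensor-parallel ranks in load_medusa_hf. This sketch assumes an even split
# along dim 0. Under that assumption, rank `tp_rank` of `tp_size` keeps the
# contiguous block of rows [tp_rank * n, (tp_rank + 1) * n), where
# n = rows // tp_size. Plain lists stand in for tensors here.
def _sketch_split_rows(rows, tp_size, tp_rank):
    """Hypothetical: return this rank's contiguous slice of `rows`."""
    if len(rows) % tp_size != 0:
        raise ValueError(
            f"{len(rows)} rows cannot be split evenly across {tp_size} ranks")
    if not 0 <= tp_rank < tp_size:
        raise ValueError(f"tp_rank {tp_rank} out of range for tp_size {tp_size}")
    n = len(rows) // tp_size
    return rows[tp_rank * n:(tp_rank + 1) * n]


# Usage: a 4 x 2 weight on 2 ranks; rank 1 keeps rows 2 and 3.
# _sketch_split_rows([[0, 0], [1, 1], [2, 2], [3, 3]], 2, 1) -> [[2, 2], [3, 3]]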


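def get_tllm_linear_weight(weight,
                           prefix,
                           bias=None,
                           use_weight_only=False,
                           plugin_weight_only_quant_type=torch.int8,
                           postfix='weight'):
    """Convert one linear layer's weight (and optional bias) to TRT-LLM tensors.

    Returns a dict with ``prefix + postfix`` mapped to the (possibly
    quantized) weight. ``prefix + 'per_channel_scale'`` is added when
    ``use_weight_only`` is set, and ``prefix + 'bias'`` when ``bias`` is given.
    """
    results = {}
    if use_weight_only:
        # The op computes one symmetric scale per index of the last axis, so
        # the [out_features, in_features] weight is transposed to
        # [in_features, out_features] to get one scale per output channel.
        # The tensor is moved to CPU before the op is called.
        v = weight.t().contiguous().cpu()
        processed_torch_weights, torch_weight_scales = \
            torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix(
                v, plugin_weight_only_quant_type)
        results[prefix + postfix] = processed_torch_weights
        results[prefix + 'per_channel_scale'] = torch_weight_scales
    else:
        # Unquantized path: store the weight unchanged, made contiguous.
        results[prefix + postfix] = weight.contiguous()

    if bias is not None:
        results[prefix + 'bias'] = bias

    return results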
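

# Illustrative sketch (hypothetical helper, not TensorRT-LLM code) of the math
# behind the weight-only path above. It assumes plain symmetric int8
# quantization with one scale per output channel. The real
# torch.ops.trtllm.symmetric_quantize_last_axis_of_batched_matrix op also
# accepts other quantization types and may rearrange the quantized weights
# for its GEMM kernels. None of that is modeled here.
# For each output channel (row of the [out_features, in_features] weight):
#     scale = max(|w|) / 127,  q = round(w / scale) clipped to [-127, 127]
# Dequantization is then approximately q * scale.
def _sketch_symmetric_int8_per_channel(weight_rows):
    """Hypothetical: quantize each row to int8; return (quantized, scales)."""
    quantized, scales = [], []
    for row in weight_rows:
        max_abs = max(abs(x) for x in row)
        # An all-zero row gets scale 1.0 to avoid dividing by zero.
        scale = max_abs / 127.0 if max_abs > 0 else 1.0
        # Round to the nearest integer, then clip to the symmetric int8 range.
        q_row = [max(-127, min(127, round(x / scale))) for x in row]
        quantized.append(q_row)
        scales.append(scale)
    return quantized, scales


# Usage: the largest magnitude in each row maps to +/-127.
# _sketch_symmetric_int8_per_channel([[254.0, -100.0], [0.0, 0.0]])
#     -> ([[127, -50], [0, 0]], [2.0, 1.0])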


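def load_medusa_hf(medusa_path: str,
                   num_medusa_heads: int,
                   num_medusa_layers: int,
                   mapping=None,
                   dtype='float32',
                   use_weight_only=False,
                   plugin_weight_only_quant_type=torch.int8):
    """Load Medusa head weights and convert them to TRT-LLM names.

    Reads ``medusa_lm_head.pt`` from ``medusa_path`` if present, otherwise
    ``medusa_lm_head.safetensors``. Every tensor is cast to ``dtype`` and
    sharded for ``mapping.tp_rank``. With ``use_weight_only``, the Medusa-layer
    linear weights (not the LM heads) are also quantized to
    ``plugin_weight_only_quant_type``.
    """
    if mapping is None:
        mapping = Mapping()
    logger.info("Loading Medusa heads' weights ...")

    ckpt_file = Path(medusa_path) / "medusa_lm_head.pt"
    is_ckpt_safetensors = not ckpt_file.exists()
    if is_ckpt_safetensors:
        ckpt_file = Path(medusa_path) / "medusa_lm_head.safetensors"
        if not ckpt_file.exists():
            raise FileNotFoundError(
                "Neither medusa_lm_head.pt nor medusa_lm_head.safetensors "
                f"was found in {medusa_path}")

    if is_ckpt_safetensors:
        logger.info("Found safetensors checkpoint ...")
        from safetensors.torch import load_file
        state_dict = load_file(ckpt_file)
    else:
        state_dict = torch.load(ckpt_file, map_location="cpu")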

    torch_dtype = str_dtype_to_torch(dtype)
    weights = {}

    for h in range(num_medusa_heads):
        # Medusa layers 0..num_medusa_layers-1 of head h: each is a linear
        # layer whose weight and bias are sharded across tensor-parallel ranks.
        for l in range(num_medusa_layers):
            prefix = f'medusa_heads.{h}.medusa_layers.{l}.linear.'
            w = state_dict[f"{h}.{l}.linear.weight"].clone().to(torch_dtype)
            b = state_dict[f"{h}.{l}.linear.bias"].clone().to(torch_dtype)

            split_w = split(w, mapping.tp_size, mapping.tp_rank)
            split_b = split(b, mapping.tp_size, mapping.tp_rank)
            weights.update(
                get_tllm_linear_weight(split_w, prefix, split_b,
                                       use_weight_only,
                                       plugin_weight_only_quant_type))

        # The entry right after the last Medusa layer is the head's LM head.
        # It is sharded the same way but never weight-only quantized.
        lm = state_dict[f"{h}.{num_medusa_layers}.weight"].clone().to(
            torch_dtype)
        weights[f'medusa_heads.{h}.lm_head.weight'] = split(
            lm, mapping.tp_size, mapping.tp_rank)

    return weights
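

# Illustrative sketch (hypothetical helper, not TensorRT-LLM code) listing the
# checkpoint-key to TRT-LLM-key mapping that load_medusa_hf applies. It
# assumes each head h stores num_medusa_layers layers as
# "{h}.{l}.linear.weight" / "{h}.{l}.linear.bias", followed by an LM head
# weight at index num_medusa_layers. With use_weight_only, the loader also
# emits a "...linear.per_channel_scale" entry per layer, which is not listed
# here. Useful for checking a checkpoint's keys before conversion.
def _sketch_medusa_key_map(num_medusa_heads, num_medusa_layers):
    """Hypothetical: return {source_key: target_key} for every tensor read."""
    key_map = {}
    for h in range(num_medusa_heads):
        for l in range(num_medusa_layers):
            for kind in ("weight", "bias"):
                key_map[f"{h}.{l}.linear.{kind}"] = (
                    f"medusa_heads.{h}.medusa_layers.{l}.linear.{kind}")
        key_map[f"{h}.{num_medusa_layers}.weight"] = (
            f"medusa_heads.{h}.lm_head.weight")
    return key_map


# Usage: find keys a loaded state_dict is missing before converting it.
# missing = set(_sketch_medusa_key_map(4, 1)) - set(state_dict)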