File size: 7,001 Bytes
5000658
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import copy
from argparse import Namespace
from functools import cmp_to_key
from typing import List

import numpy as np
import torch

from tensorrt_llm.logger import logger


def path_sorter(a, b):
    for i in range(min(len(a), len(b))):
        if a[i] != b[i]:
            return -1 if a[i] < b[i] else 1
    return 0  # shouldn't reach


path_sorting_key = cmp_to_key(path_sorter)


def expand_choices_if_needed(medusa_choices: List[List[int]]):
    """
    Do a simple check to see if the given choices are path-only or vanilla.
    """
    assert len(medusa_choices) > 0
    for c in medusa_choices:
        if len(c) > 1:
            try:
                _ = medusa_choices.index(
                    [c[0]])  # find the first parent of current path
                logger.debug(
                    "Detected vanilla-style of Medusa choices. No need to expand."
                )
                return medusa_choices  # if found, just return assuming it is already expanded
            except ValueError:
                logger.debug(
                    "Detected path-only style of Medusa choices. Expanding ...")
                break
    expanded_choices = set()
    for c in medusa_choices:
        cur = ()
        for n in c:
            cur = (*cur, n)
            expanded_choices.add(cur)
    expanded_choices = [list(c) for c in expanded_choices]
    return expanded_choices


def get_packed_mask(num_medusa_tokens, medusa_mask, max_medusa_tokens=None):
    max_medusa_tokens = num_medusa_tokens if max_medusa_tokens is None else max_medusa_tokens
    num_packed_masks = (max_medusa_tokens + 1 + 32 - 1) // 32
    medusa_packed_mask = torch.zeros((num_medusa_tokens + 1, num_packed_masks),
                                     dtype=torch.int32)
    for token_idx in range(num_medusa_tokens + 1):
        if token_idx == 0:
            medusa_packed_mask[0, 0] = 1
        else:
            mask_list = medusa_mask[token_idx - 1, :].tolist()
            # insert 1 as there is one extra new token from the original lm head.
            mask_list.insert(0, True)
            # convert binary bits into 4 int32_t
            mask_str_list = [str(int(val)) for val in mask_list]
            mask_str_list.reverse()

            for mask_idx in range(num_packed_masks):
                if mask_idx * 32 >= len(mask_str_list):
                    break
                mask_32bits_str = ''.join(mask_str_list[-(mask_idx + 1) * 32:
                                                        (-mask_idx * 32 - 1)] +
                                          [mask_str_list[(-mask_idx * 32 - 1)]])
                valid_num_bits = len(mask_32bits_str)
                first_bit1 = mask_32bits_str[0] == '1'
                mask_31bits_str = mask_32bits_str[1:]
                mask_31bits = 0 if mask_31bits_str == "" else int(
                    mask_31bits_str, 2)
                if valid_num_bits == 32:
                    mask_32bits = mask_31bits - first_bit1 * (2**(
                        valid_num_bits - 1))
                else:
                    mask_32bits = mask_31bits + first_bit1 * (2**(
                        valid_num_bits - 1))
                medusa_packed_mask[token_idx, mask_idx] = mask_32bits
    return medusa_packed_mask


def choices_2_paths(num_medusa_heads, choices):
    paths = {}
    all_paths = {}
    level_counts = [0] * num_medusa_heads
    choices.sort(key=len, reverse=True)
    for c in choices:
        k = ":".join([str(ci) for ci in c])
        if k not in all_paths:
            paths[k] = c
        for i in range(len(c)):
            k = ":".join([str(ci) for ci in c[:i + 1]])
            if k not in all_paths:
                all_paths[k] = c[:i + 1]
                level_counts[i] += 1
    return list(paths.values()), level_counts, paths, all_paths


def get_medusa_topks(num_medusa_heads, paths):
    medusa_topks = [0] * num_medusa_heads
    for p in paths:
        for i, k in enumerate(p):
            medusa_topks[i] = max(medusa_topks[i], k + 1)
    return medusa_topks


def get_medusa_tree(num_medusa_heads, medusa_topks, level_counts, paths):
    cum_topks = np.cumsum([0] + medusa_topks)
    cum_level_counts = np.cumsum([0] + level_counts)
    tree_paths = copy.deepcopy(paths)
    medusa_tree_ids = list(np.arange(medusa_topks[0]))
    medusa_position_offsets = [0] * medusa_topks[0]
    for i in range(1, num_medusa_heads):
        last_prefix = "-1"
        last = -1
        c = -1
        for pi, p in enumerate(paths):
            if i < len(p):
                prefix_str = ":".join([str(k) for k in p[:i]])
                if last_prefix != prefix_str or last != p[i]:
                    # new path
                    medusa_position_offsets.append(i)
                    medusa_tree_ids.append(p[i] + cum_topks[i])
                    last_prefix = prefix_str
                    last = p[i]
                    c += 1
                tree_paths[pi][i] = cum_level_counts[i] + c
    return medusa_tree_ids, medusa_position_offsets, tree_paths


def get_medusa_mask(medusa_tree_ids, medusa_paths):
    medusa_mask = torch.zeros((len(medusa_tree_ids), len(medusa_tree_ids)))
    medusa_mask[:, 0] = 1
    for p in medusa_paths:
        for i, idx in enumerate(p):
            if idx < 0:
                continue
            for j in range(i + 1):
                medusa_mask[idx, p[j]] = 1
    return medusa_mask


def _medusa_setup(choices_or_paths, num_medusa_heads=None):
    choices = copy.deepcopy(choices_or_paths)
    sorted_choices = sorted(choices, key=path_sorting_key)
    if num_medusa_heads is None:
        num_medusa_heads = max([len(c) for c in sorted_choices])
    paths, level_counts, _, _ = choices_2_paths(num_medusa_heads,
                                                sorted_choices)
    paths = sorted(paths, key=path_sorting_key)
    medusa_topks = get_medusa_topks(num_medusa_heads, paths)
    medusa_tree_ids, medusa_position_offsets, tree_paths = get_medusa_tree(
        num_medusa_heads, medusa_topks, level_counts, paths)

    num_medusa_tokens = len(medusa_tree_ids)
    # now do the padding before converting to torch.Tensor
    medusa_paths = []
    for p in tree_paths:
        medusa_paths.append(
            torch.tensor([-1] + p + ([-2] * (num_medusa_heads - len(p)))))
    medusa_topks = torch.tensor(medusa_topks)
    medusa_paths = torch.stack(medusa_paths) + 1
    medusa_tree_ids = torch.tensor([-1] + medusa_tree_ids) + 1
    medusa_position_offsets = torch.tensor([-1] + medusa_position_offsets) + 1
    medusa_mask = get_medusa_mask(medusa_tree_ids, medusa_paths)
    medusa_packed_mask = get_packed_mask(num_medusa_tokens, medusa_mask[1:, 1:])

    return Namespace(
        medusa_mask=medusa_mask.cuda(),
        medusa_packed_mask=medusa_packed_mask.cuda(),
        medusa_topks=medusa_topks.cuda(),
        medusa_paths=medusa_paths.cuda(),
        medusa_tree_ids=medusa_tree_ids.cuda(),
        medusa_position_offsets=medusa_position_offsets.cuda(),
    )