|
|
import copy |
|
|
from argparse import Namespace |
|
|
from functools import cmp_to_key |
|
|
from typing import List |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from tensorrt_llm.logger import logger |
|
|
|
|
|
|
|
|
def path_sorter(a, b): |
|
|
for i in range(min(len(a), len(b))): |
|
|
if a[i] != b[i]: |
|
|
return -1 if a[i] < b[i] else 1 |
|
|
return 0 |
|
|
|
|
|
|
|
|
path_sorting_key = cmp_to_key(path_sorter) |
|
|
|
|
|
|
|
|
def expand_choices_if_needed(medusa_choices: List[List[int]]): |
|
|
""" |
|
|
Do a simple check to see if the given choices are path-only or vanilla. |
|
|
""" |
|
|
assert len(medusa_choices) > 0 |
|
|
for c in medusa_choices: |
|
|
if len(c) > 1: |
|
|
try: |
|
|
_ = medusa_choices.index( |
|
|
[c[0]]) |
|
|
logger.debug( |
|
|
"Detected vanilla-style of Medusa choices. No need to expand." |
|
|
) |
|
|
return medusa_choices |
|
|
except ValueError: |
|
|
logger.debug( |
|
|
"Detected path-only style of Medusa choices. Expanding ...") |
|
|
break |
|
|
expanded_choices = set() |
|
|
for c in medusa_choices: |
|
|
cur = () |
|
|
for n in c: |
|
|
cur = (*cur, n) |
|
|
expanded_choices.add(cur) |
|
|
expanded_choices = [list(c) for c in expanded_choices] |
|
|
return expanded_choices |
|
|
|
|
|
|
|
|
def get_packed_mask(num_medusa_tokens, medusa_mask, max_medusa_tokens=None): |
|
|
max_medusa_tokens = num_medusa_tokens if max_medusa_tokens is None else max_medusa_tokens |
|
|
num_packed_masks = (max_medusa_tokens + 1 + 32 - 1) // 32 |
|
|
medusa_packed_mask = torch.zeros((num_medusa_tokens + 1, num_packed_masks), |
|
|
dtype=torch.int32) |
|
|
for token_idx in range(num_medusa_tokens + 1): |
|
|
if token_idx == 0: |
|
|
medusa_packed_mask[0, 0] = 1 |
|
|
else: |
|
|
mask_list = medusa_mask[token_idx - 1, :].tolist() |
|
|
|
|
|
mask_list.insert(0, True) |
|
|
|
|
|
mask_str_list = [str(int(val)) for val in mask_list] |
|
|
mask_str_list.reverse() |
|
|
|
|
|
for mask_idx in range(num_packed_masks): |
|
|
if mask_idx * 32 >= len(mask_str_list): |
|
|
break |
|
|
mask_32bits_str = ''.join(mask_str_list[-(mask_idx + 1) * 32: |
|
|
(-mask_idx * 32 - 1)] + |
|
|
[mask_str_list[(-mask_idx * 32 - 1)]]) |
|
|
valid_num_bits = len(mask_32bits_str) |
|
|
first_bit1 = mask_32bits_str[0] == '1' |
|
|
mask_31bits_str = mask_32bits_str[1:] |
|
|
mask_31bits = 0 if mask_31bits_str == "" else int( |
|
|
mask_31bits_str, 2) |
|
|
if valid_num_bits == 32: |
|
|
mask_32bits = mask_31bits - first_bit1 * (2**( |
|
|
valid_num_bits - 1)) |
|
|
else: |
|
|
mask_32bits = mask_31bits + first_bit1 * (2**( |
|
|
valid_num_bits - 1)) |
|
|
medusa_packed_mask[token_idx, mask_idx] = mask_32bits |
|
|
return medusa_packed_mask |
|
|
|
|
|
|
|
|
def choices_2_paths(num_medusa_heads, choices): |
|
|
paths = {} |
|
|
all_paths = {} |
|
|
level_counts = [0] * num_medusa_heads |
|
|
choices.sort(key=len, reverse=True) |
|
|
for c in choices: |
|
|
k = ":".join([str(ci) for ci in c]) |
|
|
if k not in all_paths: |
|
|
paths[k] = c |
|
|
for i in range(len(c)): |
|
|
k = ":".join([str(ci) for ci in c[:i + 1]]) |
|
|
if k not in all_paths: |
|
|
all_paths[k] = c[:i + 1] |
|
|
level_counts[i] += 1 |
|
|
return list(paths.values()), level_counts, paths, all_paths |
|
|
|
|
|
|
|
|
def get_medusa_topks(num_medusa_heads, paths): |
|
|
medusa_topks = [0] * num_medusa_heads |
|
|
for p in paths: |
|
|
for i, k in enumerate(p): |
|
|
medusa_topks[i] = max(medusa_topks[i], k + 1) |
|
|
return medusa_topks |
|
|
|
|
|
|
|
|
def get_medusa_tree(num_medusa_heads, medusa_topks, level_counts, paths): |
|
|
cum_topks = np.cumsum([0] + medusa_topks) |
|
|
cum_level_counts = np.cumsum([0] + level_counts) |
|
|
tree_paths = copy.deepcopy(paths) |
|
|
medusa_tree_ids = list(np.arange(medusa_topks[0])) |
|
|
medusa_position_offsets = [0] * medusa_topks[0] |
|
|
for i in range(1, num_medusa_heads): |
|
|
last_prefix = "-1" |
|
|
last = -1 |
|
|
c = -1 |
|
|
for pi, p in enumerate(paths): |
|
|
if i < len(p): |
|
|
prefix_str = ":".join([str(k) for k in p[:i]]) |
|
|
if last_prefix != prefix_str or last != p[i]: |
|
|
|
|
|
medusa_position_offsets.append(i) |
|
|
medusa_tree_ids.append(p[i] + cum_topks[i]) |
|
|
last_prefix = prefix_str |
|
|
last = p[i] |
|
|
c += 1 |
|
|
tree_paths[pi][i] = cum_level_counts[i] + c |
|
|
return medusa_tree_ids, medusa_position_offsets, tree_paths |
|
|
|
|
|
|
|
|
def get_medusa_mask(medusa_tree_ids, medusa_paths): |
|
|
medusa_mask = torch.zeros((len(medusa_tree_ids), len(medusa_tree_ids))) |
|
|
medusa_mask[:, 0] = 1 |
|
|
for p in medusa_paths: |
|
|
for i, idx in enumerate(p): |
|
|
if idx < 0: |
|
|
continue |
|
|
for j in range(i + 1): |
|
|
medusa_mask[idx, p[j]] = 1 |
|
|
return medusa_mask |
|
|
|
|
|
|
|
|
def _medusa_setup(choices_or_paths, num_medusa_heads=None): |
|
|
choices = copy.deepcopy(choices_or_paths) |
|
|
sorted_choices = sorted(choices, key=path_sorting_key) |
|
|
if num_medusa_heads is None: |
|
|
num_medusa_heads = max([len(c) for c in sorted_choices]) |
|
|
paths, level_counts, _, _ = choices_2_paths(num_medusa_heads, |
|
|
sorted_choices) |
|
|
paths = sorted(paths, key=path_sorting_key) |
|
|
medusa_topks = get_medusa_topks(num_medusa_heads, paths) |
|
|
medusa_tree_ids, medusa_position_offsets, tree_paths = get_medusa_tree( |
|
|
num_medusa_heads, medusa_topks, level_counts, paths) |
|
|
|
|
|
num_medusa_tokens = len(medusa_tree_ids) |
|
|
|
|
|
medusa_paths = [] |
|
|
for p in tree_paths: |
|
|
medusa_paths.append( |
|
|
torch.tensor([-1] + p + ([-2] * (num_medusa_heads - len(p))))) |
|
|
medusa_topks = torch.tensor(medusa_topks) |
|
|
medusa_paths = torch.stack(medusa_paths) + 1 |
|
|
medusa_tree_ids = torch.tensor([-1] + medusa_tree_ids) + 1 |
|
|
medusa_position_offsets = torch.tensor([-1] + medusa_position_offsets) + 1 |
|
|
medusa_mask = get_medusa_mask(medusa_tree_ids, medusa_paths) |
|
|
medusa_packed_mask = get_packed_mask(num_medusa_tokens, medusa_mask[1:, 1:]) |
|
|
|
|
|
return Namespace( |
|
|
medusa_mask=medusa_mask.cuda(), |
|
|
medusa_packed_mask=medusa_packed_mask.cuda(), |
|
|
medusa_topks=medusa_topks.cuda(), |
|
|
medusa_paths=medusa_paths.cuda(), |
|
|
medusa_tree_ids=medusa_tree_ids.cuda(), |
|
|
medusa_position_offsets=medusa_position_offsets.cuda(), |
|
|
) |
|
|
|