leideng's picture
download
raw
4.38 kB
# -*- coding: utf-8 -*-
import logging
import os
from typing import List, Tuple
import numpy as np
from torch.utils.cpp_extension import load
# Module-level logger used by NgramCache.debug_result.
logger = logging.getLogger(__name__)

# Directory containing this file; the C++ sources are looked up next to it.
_abs_path = os.path.dirname(os.path.abspath(__file__))

# Compile and load the C++ n-gram cache extension at import time (a module-level
# side effect: importing this file triggers the build, which needs a C++
# compiler supporting -std=c++20). The resulting module provides the `Param`
# and `Ngram` types used by NgramCache.
ngram_cache_cpp = load(
    name="ngram_cache_cpp",
    sources=[
        f"{_abs_path}/ngram_cache_binding.cpp",
        f"{_abs_path}/ngram.cpp",
    ],
    extra_cflags=["-O3", "-std=c++20"],
)
class NgramCache:
    """Python wrapper around the C++ n-gram cache extension (``ngram_cache_cpp``).

    Token sequences are inserted with :meth:`batch_put`; :meth:`batch_get`
    returns draft token ids and tree masks for a batch of query prefixes, and
    :meth:`debug_result` logs those drafts as root-to-leaf token paths.
    """

    def __init__(
        self,
        branch_length=18,
        min_match_window_size=1,
        max_match_window_size=10,
        min_bfs_breadth=1,
        max_bfs_breadth=8,
        draft_token_num=8,
        match_type="BFS",
        capacity=1000000,
    ):
        """Create the underlying C++ cache.

        Args:
            branch_length, min_match_window_size, max_match_window_size,
            min_bfs_breadth, max_bfs_breadth, draft_token_num, match_type:
                Copied verbatim onto an ``ngram_cache_cpp.Param``; their exact
                semantics are defined by the C++ implementation (not visible
                here).
            capacity: First argument to ``ngram_cache_cpp.Ngram``; presumably
                the maximum number of stored entries — TODO confirm in ngram.cpp.
        """
        param = ngram_cache_cpp.Param()
        param.branch_length = branch_length
        param.min_match_window_size = min_match_window_size
        param.max_match_window_size = max_match_window_size
        param.min_bfs_breadth = min_bfs_breadth
        param.max_bfs_breadth = max_bfs_breadth
        param.draft_token_num = draft_token_num
        param.match_type = match_type
        self.cache = ngram_cache_cpp.Ngram(capacity, param)
        # Not referenced elsewhere in this class; presumably read by callers
        # outside this file.
        self.default_mask = np.ones((1, 1), dtype=np.int64)
        # Used by debug_result to reshape the flat batch_get output.
        self.draft_token_num = draft_token_num

    def batch_put(self, batch_tokens: List[List[int]]):
        """Insert each token sequence into the cache via C++ ``asyncInsert``.

        Judging by the name the insert may complete in the background; call
        :meth:`synchronize` before relying on it (as the ``__main__`` demo does).
        """
        self.cache.asyncInsert(batch_tokens)

    def synchronize(self):
        """Forward to the C++ ``synchronize()`` (presumably waits for pending inserts)."""
        self.cache.synchronize()

    def reset(self):
        """Forward to the C++ ``reset()`` (presumably clears the cache)."""
        self.cache.reset()

    def batch_get(self, batch_tokens: List[List[int]]) -> Tuple[np.ndarray, np.ndarray]:
        """Match each query prefix against the cache.

        Returns:
            ``(token_ids, mask)`` as NumPy arrays built from the C++ result's
            ``token`` and ``mask`` fields. :meth:`debug_result` reshapes them to
            ``(-1, draft_token_num)`` and ``(-1, draft_token_num,
            draft_token_num)``, so they are presumably flattened over the batch.
        """
        result = self.cache.batchMatch(batch_tokens)
        return np.array(result.token), np.array(result.mask)

    def leaf_paths_from_mask(
        self, tokens: List[int], tree_mask: List[List[int]]
    ) -> List[List[int]]:
        """Return the token paths of the leaf rows of a binary tree mask.

        Row ``r`` of ``tree_mask`` marks with ``1`` the columns on the path to
        draft token ``r``. A row is a leaf when its set of marked columns is
        not a proper subset of any other row's set, i.e. its path is not a
        prefix of any other path. When several rows mark exactly the same
        columns, only the last of them is kept. Every pair of rows is
        compared, so ancestors need not appear before their descendants.

        Args:
            tokens: Token ids, one per mask column.
            tree_mask: ``n x n`` binary matrix given as a list of rows.

        Returns:
            For each leaf row, in row order, the tokens at the columns marked
            ``1`` in that row, in column order.
        """
        row_sets = [
            {col for col, v in enumerate(row) if v == 1} for row in tree_mask
        ]
        paths = []
        for r, cur_set in enumerate(row_sets):
            # Skip rows extended by some other row (proper subset), and all but
            # the last of a group of identical rows.
            dominated = any(
                cur_set < other or (j > r and cur_set == other)
                for j, other in enumerate(row_sets)
                if j != r
            )
            if dominated:
                continue
            paths.append(
                [tokens[col] for col in range(len(tokens)) if tree_mask[r][col] == 1]
            )
        return paths

    def debug_result(
        self, decoding_ids: np.ndarray, decoding_masks: np.ndarray, tokenizer=None
    ):
        """Log the draft paths contained in a :meth:`batch_get` result.

        ``decoding_ids`` is reshaped to ``(-1, draft_token_num)`` and
        ``decoding_masks`` to ``(-1, draft_token_num, draft_token_num)``; for
        each batch row the leaf paths are logged, decoded with ``tokenizer``
        when one is given.
        """
        decoding_ids = decoding_ids.reshape(-1, self.draft_token_num)
        decoding_masks = decoding_masks.reshape(
            -1, self.draft_token_num, self.draft_token_num
        )
        # Lazy %-args: the arrays are only formatted if INFO is enabled.
        logger.info(
            "\ndecoding_ids=%r\ndecoding_masks=%r", decoding_ids, decoding_masks
        )
        for i in range(decoding_ids.shape[0]):
            leaf_paths = self.leaf_paths_from_mask(
                decoding_ids[i].tolist(), decoding_masks[i].tolist()
            )
            if tokenizer is None:
                logger.info("draft path %d: %s", i, leaf_paths)
            else:
                logger.info("result %d:", i)
                for leaf_path in leaf_paths:
                    # NOTE(review): `ensure_ascii` is not a documented decode()
                    # argument for common tokenizers; kept as-is — confirm the
                    # tokenizer in use accepts it.
                    logger.info(
                        "draft path %d: %s -> %s",
                        i,
                        leaf_path,
                        tokenizer.decode(leaf_path, ensure_ascii=False),
                    )
# Demo: build a small cache from two token sequences and log the draft trees it
# proposes for a few query prefixes.
if __name__ == "__main__":
    # A logging %-style format string: no f-prefix needed, and named so it does
    # not shadow the built-in `format`.
    log_format = "%(levelname)s %(asctime)s %(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(
        level=logging.DEBUG,
        format=log_format,
        datefmt="%Y-%m-%d %H:%M:%S",
        force=True,
    )
    token_ids = [
        [1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
        [1, 2, 3, 44, 55, 66, 77, 88, 99, 100],
    ]
    cache = NgramCache(branch_length=12, draft_token_num=8)
    cache.batch_put(token_ids)
    # batch_put uses asyncInsert; synchronize before matching.
    cache.synchronize()
    # The last query ends with 999, a token never inserted above.
    decoding_ids, decoding_masks = cache.batch_get([[1, 2, 3], [3, 44], [3, 6, 999]])
    cache.debug_result(decoding_ids, decoding_masks)

Xet Storage Details

Size:
4.38 kB
·
Xet hash:
d71774fff4f99b84a15f9fb96ce421e5007e0e5ce485ce98348212e5a3b74cd7

Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.