| from __future__ import annotations | |
| import logging | |
| import os | |
| import time | |
| from contextlib import contextmanager | |
| from typing import TYPE_CHECKING, List | |
| import torch | |
| import triton | |
| import triton.language as tl | |
| from huggingface_hub import snapshot_download | |
| from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject | |
| from sglang.srt.distributed.parallel_state import ( | |
| GroupCoordinator, | |
| patch_tensor_parallel_group, | |
| ) | |
| from sglang.srt.environ import envs | |
| from sglang.srt.layers.logits_processor import LogitsProcessorOutput | |
| from sglang.srt.managers.schedule_batch import Req | |
| from sglang.srt.utils import is_cuda, is_hip | |
| if TYPE_CHECKING: | |
| from sglang.srt.speculative.eagle_info import EagleVerifyInput | |
| if is_cuda(): | |
| from sgl_kernel import fast_topk | |
| elif is_hip(): | |
| from sgl_kernel import fast_topk | |
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)

# Simulate acceptance length for benchmarking purposes
SIMULATE_ACC_LEN = envs.SGLANG_SIMULATE_ACC_LEN.get()  # turn off if < 0
# Sampling strategy used by generate_simulated_accept_index
# ("multinomial" or "match-expected").
SIMULATE_ACC_METHOD = envs.SGLANG_SIMULATE_ACC_METHOD.get()

# A warning is logged when one grammar tree traversal takes longer than this
# many seconds (compared against time.perf_counter() deltas).
TREE_TRAVERSE_TIME_THRESHOLD = 1  # TODO: set this properly

TREE_SPEC_KERNEL_AVAILABLE = is_cuda()  # This kernel is only available for CUDA now
@triton.jit
def create_extend_after_decode_spec_info(
    verified_id,
    seq_lens,
    accept_lens,
    positions,
    new_verified_id,
    bs_upper: tl.constexpr,
):
    """Triton kernel: build positions and the last verified id per request.

    One program handles one request (``pid``). It
      1. writes ``accept_lens[pid]`` consecutive position values, ending at
         ``seq_lens[pid] - 1``, into ``positions`` starting at the exclusive
         prefix sum of ``accept_lens``;
      2. copies the last accepted entry of this request from the flattened
         ``verified_id`` buffer into ``new_verified_id[pid]``.

    ``bs_upper`` is the compile-time width of the index range. The same range
    is used for the batch prefix sum and for the per-request positions, so it
    has to cover both the batch size and the largest accept length.
    """
    pid = tl.program_id(axis=0)
    offsets = tl.arange(0, bs_upper)
    seq_length = tl.load(seq_lens + pid)
    accept_length = tl.load(accept_lens + pid)

    # Exclusive prefix sum: number of accepted tokens of all earlier requests.
    accept_len_cumsum = tl.sum(
        tl.load(accept_lens + offsets, mask=offsets < pid, other=0)
    )
    positions_ptr = positions + accept_len_cumsum
    mask = offsets < accept_length
    tl.store(positions_ptr + offsets, seq_length - accept_length + offsets, mask)

    # Index of the last accepted token of this request in the flat buffer.
    accept_len_cumsum += accept_length - 1
    verified_id_data = tl.load(verified_id + accept_len_cumsum)
    tl.store(new_verified_id + pid, verified_id_data)
@triton.jit
def assign_req_to_token_pool(
    req_pool_indices,
    req_to_token,
    start_offset,
    end_offset,
    out_cache_loc,
    pool_len: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Triton kernel: scatter newly allocated cache slots into ``req_to_token``.

    One program handles one request (``pid``). The slots for that request are
    read from the flat ``out_cache_loc`` buffer, starting at the sum of
    ``end_offset - start_offset`` over all earlier requests, and written to
    row ``req_pool_indices[pid]`` of ``req_to_token`` (row stride
    ``pool_len``) at columns ``[start_offset[pid], end_offset[pid])``.

    ``bs_upper`` is the compile-time width used for the batch prefix sum.
    """
    BLOCK_SIZE: tl.constexpr = 32
    pid = tl.program_id(axis=0)
    kv_start = tl.load(start_offset + pid)
    kv_end = tl.load(end_offset + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len

    # Exclusive prefix sum of the per-request lengths gives the read offset.
    length_offset = tl.arange(0, bs_upper)
    start = tl.load(start_offset + length_offset, mask=length_offset < pid, other=0)
    end = tl.load(end_offset + length_offset, mask=length_offset < pid, other=0)
    out_offset = tl.sum(end - start, axis=0)

    out_cache_ptr = out_cache_loc + out_offset

    save_offset = tl.arange(0, BLOCK_SIZE) + kv_start
    load_offset = tl.arange(0, BLOCK_SIZE)

    # Copy in chunks of BLOCK_SIZE; the mask trims the final partial chunk.
    num_loop = tl.cdiv(kv_end - kv_start, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = save_offset < kv_end
        data = tl.load(out_cache_ptr + load_offset, mask=mask)
        tl.store(token_pool + save_offset, data, mask=mask)
        save_offset += BLOCK_SIZE
        load_offset += BLOCK_SIZE
@triton.jit
def assign_draft_cache_locs(
    req_pool_indices,
    req_to_token,
    seq_lens,
    extend_lens,
    num_new_pages_per_topk,
    out_cache_loc,
    pool_len: tl.constexpr,
    topk: tl.constexpr,
    speculative_num_steps: tl.constexpr,
    page_size: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
):
    """Triton kernel: place draft-token cache slots into ``req_to_token``.

    One program handles one request (``pid``).

    * ``page_size == 1 or topk == 1``: every request owns exactly
      ``topk * speculative_num_steps`` consecutive entries of
      ``out_cache_loc``; they are copied to the request's ``req_to_token``
      row starting at column ``seq_lens[pid]`` and the kernel returns.
    * otherwise: the request owns ``extend_lens[pid]`` entries (offset is the
      prefix sum of ``extend_lens``). After copying them (part 1), the
      entries of the last partial prefix page are replicated once per top-k
      branch (part 2), and the slots each branch will use are written back
      to ``out_cache_loc`` in a dense
      ``[pid, topk_id, step]`` layout (part 3).

    ``bs_upper`` and ``iter_upper`` are compile-time widths of the index
    ranges used for the batch prefix sum and the per-branch step range.
    """
    BLOCK_SIZE: tl.constexpr = 128
    pid = tl.program_id(axis=0)

    if page_size == 1 or topk == 1:
        copy_len = topk * speculative_num_steps
        out_cache_ptr = out_cache_loc + pid * topk * speculative_num_steps
    else:
        bs_offset = tl.arange(0, bs_upper)
        copy_len = tl.load(extend_lens + pid)
        # other=0 keeps the masked-out lanes out of the prefix sum; without it
        # their contents are unspecified.
        cum_copy_len = tl.sum(
            tl.load(extend_lens + bs_offset, mask=bs_offset < pid, other=0)
        )
        out_cache_ptr = out_cache_loc + cum_copy_len

    # Part 1: Copy from out_cache_loc to req_to_token
    kv_start = tl.load(seq_lens + pid)
    token_pool = req_to_token + tl.load(req_pool_indices + pid) * pool_len
    num_loop = tl.cdiv(copy_len, BLOCK_SIZE)
    for i in range(num_loop):
        copy_offset = tl.arange(0, BLOCK_SIZE) + i * BLOCK_SIZE
        mask = copy_offset < copy_len
        data = tl.load(out_cache_ptr + copy_offset, mask=mask)
        tl.store(token_pool + kv_start + copy_offset, data, mask=mask)

    if page_size == 1 or topk == 1:
        return

    # Part 2: Copy the indices for the last partial page
    prefix_len = tl.load(seq_lens + pid)
    last_page_len = prefix_len % page_size
    offsets = tl.arange(0, page_size)
    mask = offsets < last_page_len
    num_new_pages_per_topk_ = tl.load(num_new_pages_per_topk + pid)
    prefix_base = token_pool + prefix_len - last_page_len

    for topk_id in range(topk):
        value = tl.load(prefix_base + offsets, mask=mask)
        tl.store(
            prefix_base + topk_id * num_new_pages_per_topk_ * page_size + offsets,
            value,
            mask=mask,
        )

    # Part 3: Remove the padding in out_cache_loc
    iter_offset = tl.arange(0, iter_upper)
    for topk_id in range(topk):
        indices = tl.load(
            prefix_base
            + topk_id * num_new_pages_per_topk_ * page_size
            + last_page_len
            + iter_offset,
            mask=iter_offset < speculative_num_steps,
        )
        tl.store(
            out_cache_loc
            + pid * topk * speculative_num_steps
            + topk_id * speculative_num_steps
            + iter_offset,
            indices,
            mask=iter_offset < speculative_num_steps,
        )
@triton.jit
def generate_draft_decode_kv_indices(
    req_pool_indices,
    req_to_token,
    paged_kernel_lens,
    kv_indices,
    kv_indptr,
    positions,
    pool_len: tl.constexpr,
    kv_indices_stride: tl.constexpr,
    kv_indptr_stride: tl.constexpr,
    bs_upper: tl.constexpr,
    iter_upper: tl.constexpr,
    num_tokens_upper: tl.constexpr,
    page_size: tl.constexpr,
):
    """Triton kernel: build ``kv_indices`` / ``kv_indptr`` for every draft step.

    The launch grid is read as ``(step, sequence, topk branch)``: the sizes of
    axes 0, 1 and 2 are used as the number of steps, sequences and branches.
    Each program writes, into the ``kv_indices`` / ``kv_indptr`` slice of its
    step (selected with ``kv_indices_stride`` / ``kv_indptr_stride``):

    * the ``paged_kernel_lens[bid]`` prefix slots of the sequence, read from
      row ``req_pool_indices[bid]`` of ``req_to_token``;
    * the ``step + 1`` draft slots of its top-k branch;
    * one ``kv_indptr`` entry computed from the sum of the first ``zid``
      entries of ``positions`` plus ``zid * (step + 1)``.

    ``bs_upper``, ``iter_upper`` and ``num_tokens_upper`` are compile-time
    widths of the index ranges used for the respective reductions / copies.
    """
    BLOCK_SIZE: tl.constexpr = 128
    iters = tl.program_id(axis=0)
    bid = tl.program_id(axis=1)
    topk_id = tl.program_id(axis=2)

    num_steps = tl.num_programs(axis=0)
    num_seqs = tl.num_programs(axis=1)
    topk = tl.num_programs(axis=2)

    # Move to the output slice that belongs to this draft step.
    kv_indices += kv_indices_stride * iters
    kv_indptr += kv_indptr_stride * iters
    # From here on `iters` is the number of draft tokens per branch (1-based).
    iters += 1

    load_offset = tl.arange(0, bs_upper)
    seq_lens = tl.load(paged_kernel_lens + load_offset, mask=load_offset < bid, other=0)
    seq_len = tl.load(paged_kernel_lens + bid)
    cum_seq_len = tl.sum(seq_lens)

    # Update kv_indices
    kv_offset = cum_seq_len * topk + bid * iters * topk + topk_id * (seq_len + iters)
    kv_ptr = kv_indices + kv_offset
    token_pool_ptr = req_to_token + tl.load(req_pool_indices + bid) * pool_len

    # Copy the prefix slots in chunks of BLOCK_SIZE.
    kv_offset = tl.arange(0, BLOCK_SIZE)
    num_loop = tl.cdiv(seq_len, BLOCK_SIZE)
    for _ in range(num_loop):
        mask = kv_offset < seq_len
        data = tl.load(token_pool_ptr + kv_offset, mask=mask)
        tl.store(kv_ptr + kv_offset, data, mask=mask)
        kv_offset += BLOCK_SIZE

    # Append the draft slots of this branch after the prefix.
    extend_offset = tl.arange(0, iter_upper)
    if page_size == 1 or topk == 1:
        # Branches are stored back to back, num_steps slots each.
        extend_data = tl.load(
            token_pool_ptr + seq_len + topk_id * num_steps + tl.arange(0, iter_upper),
            mask=extend_offset < iters,
        )
    else:
        # Each branch starts on its own group of pages, after a copy of the
        # last partial prefix page.
        prefix_len = seq_len
        last_page_len = prefix_len % page_size
        num_new_pages_per_topk = (
            last_page_len + num_steps + page_size - 1
        ) // page_size
        prefix_base = seq_len // page_size * page_size
        start = (
            prefix_base + topk_id * num_new_pages_per_topk * page_size + last_page_len
        )
        extend_data = tl.load(
            token_pool_ptr + start + extend_offset,
            mask=extend_offset < iters,
        )

    tl.store(kv_ptr + seq_len + extend_offset, extend_data, mask=extend_offset < iters)

    # Update kv_indptr
    bs_offset = tl.arange(0, num_tokens_upper)

    zid = bid * topk + topk_id
    # The program with flat index 0 writes the final entry
    # (index num_seqs * topk) instead of entry 0.
    if zid == 0:
        zid = num_seqs * topk
    positions = tl.load(positions + bs_offset, mask=bs_offset < zid, other=0)
    base = tl.sum(positions)
    tl.store(kv_indptr + zid, base + zid * iters)
@triton.jit
def align_evict_mask_to_page_size(
    seq_lens,
    evict_mask,
    page_size: tl.constexpr,
    num_draft_tokens: tl.constexpr,
    BLOCK_SIZE: tl.constexpr,
):
    """Triton kernel: clear evict flags inside the last kept page.

    One program handles one row (``bid``) of ``evict_mask``, which is read as
    ``num_draft_tokens`` flags per request. With ``num_false`` being the
    number of unset flags in the row, the kernel finds the page that contains
    absolute position ``seq_lens[bid] + num_false - 1`` and stores ``False``
    for every draft position that falls inside that page.

    NOTE(review): presumably this keeps the page holding the last retained
    token from being partially evicted; confirm against the caller.

    ``BLOCK_SIZE`` is the compile-time width used to load the row and has to
    be at least ``num_draft_tokens``.
    """
    t_range = tl.arange(0, BLOCK_SIZE)

    bid = tl.program_id(axis=0)
    seq_len = tl.load(seq_lens + bid)
    io_mask = t_range < num_draft_tokens
    mask_row = tl.load(
        evict_mask + bid * num_draft_tokens + t_range, mask=io_mask, other=0
    )

    num_trues = tl.sum(mask_row)
    num_false = num_draft_tokens - num_trues

    # Page start (relative to the first draft token) of the page that holds
    # the last token whose flag is not set.
    start = (seq_len + num_false - 1) // page_size * page_size - seq_len
    for i in range(max(start, 0), min(start + page_size, num_draft_tokens)):
        tl.store(evict_mask + bid * num_draft_tokens + i, False)
@triton.jit
def get_target_cache_loc(
    tgt_cache_loc,
    to_free_slots,
    accept_length,
    to_free_num_slots,
    out_cache_loc,
    num_verify_tokens: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
    bs_upper: tl.constexpr,
):
    """Triton kernel: split each request's verify slots into kept and freed.

    One program handles one request (``bid``); ``out_cache_loc`` is read as
    ``num_verify_tokens`` slots per request.

    * The first ``accept_length[bid] + 1`` slots of the row are packed into
      ``tgt_cache_loc`` at offset ``sum(accept_length[:bid]) + bid``.
    * The last ``to_free_num_slots[bid]`` slots of the row are packed into
      ``to_free_slots`` at offset ``sum(to_free_num_slots[:bid])``.

    ``num_verify_tokens_upper`` and ``bs_upper`` are compile-time widths of
    the per-row and per-batch index ranges.
    """
    bid = tl.program_id(axis=0)
    offset = tl.arange(0, num_verify_tokens_upper)
    bs_offset = tl.arange(0, bs_upper)

    # write the first part to tgt_cache_loc
    # other=0 keeps the masked-out lanes out of the prefix sum; without it
    # their contents are unspecified.
    accept_len_all = tl.load(
        accept_length + bs_offset, mask=bs_offset < bid, other=0
    )
    tgt_cache_loc_start = tl.sum(accept_len_all) + bid
    copy_len = tl.load(accept_length + bid) + 1
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + offset, mask=offset < copy_len
    )
    tl.store(
        tgt_cache_loc + tgt_cache_loc_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )

    # write the second part to to_free_slots
    to_free_num_slots_all = tl.load(
        to_free_num_slots + bs_offset, mask=bs_offset < bid, other=0
    )
    to_free_num_slots_cur = tl.load(to_free_num_slots + bid)
    out_cache_loc_start = num_verify_tokens - to_free_num_slots_cur
    to_free_slots_start = tl.sum(to_free_num_slots_all)

    copy_len = to_free_num_slots_cur
    out_cache_loc_row = tl.load(
        out_cache_loc + bid * num_verify_tokens + out_cache_loc_start + offset,
        mask=offset < copy_len,
    )
    tl.store(
        to_free_slots + to_free_slots_start + offset,
        out_cache_loc_row,
        mask=offset < copy_len,
    )
def get_src_tgt_cache_loc(
    seq_lens: torch.Tensor,
    out_cache_loc: torch.Tensor,
    accept_index: torch.Tensor,
    accept_length: torch.Tensor,
    draft_token_num: int,
    page_size: int,
):
    """Gather accepted cache slots and count the slots that can be freed.

    Returns a 3-tuple:
      * ``src_cache_loc``: ``out_cache_loc`` gathered at ``accept_index``;
      * ``tgt_cache_loc``: an uninitialised tensor shaped like
        ``src_cache_loc``, to be filled by the caller;
      * ``to_free_num_slots``: per request, how many of the
        ``draft_token_num`` newly extended slots lie beyond the page-aligned
        end of the accepted tokens.
    """
    # Length of each sequence if every draft token were kept.
    total_len = seq_lens + draft_token_num

    # End of the accepted tokens (accepted drafts + 1), rounded up to a page
    # boundary but never past what was actually allocated.
    accepted_end = seq_lens + accept_length + 1
    page_aligned_end = (accepted_end + page_size - 1) // page_size * page_size
    keep_len = torch.minimum(page_aligned_end, total_len)

    src_cache_loc = out_cache_loc[accept_index]
    tgt_cache_loc = torch.empty_like(src_cache_loc)
    return src_cache_loc, tgt_cache_loc, total_len - keep_len
@triton.jit
def filter_finished_cache_loc_kernel(
    out_cache_loc,
    tgt_cache_loc,
    accept_length,
    accept_length_filter,
    bs_upper: tl.constexpr,
    num_verify_tokens_upper: tl.constexpr,
):
    """Triton kernel: compact ``tgt_cache_loc`` using the filtered lengths.

    One program handles one request (``bid``). The request's segment starts
    at ``sum(accept_length[:bid]) + bid`` in ``tgt_cache_loc``; its first
    ``accept_length_filter[bid]`` entries are copied to ``out_cache_loc`` at
    offset ``sum(accept_length_filter[:bid])``. Requests whose filtered
    length is zero therefore contribute nothing to the output.

    ``bs_upper`` and ``num_verify_tokens_upper`` are compile-time widths of
    the per-batch and per-request index ranges.
    """
    bid = tl.program_id(0)
    bs_offset = tl.arange(0, bs_upper)

    # other=0 keeps the masked-out lanes out of the prefix sums; without it
    # their contents are unspecified.
    accept_length_all = tl.load(
        accept_length + bs_offset, mask=bs_offset < bid, other=0
    )
    old_start = tl.sum(accept_length_all) + bid

    accept_length_filter_all = tl.load(
        accept_length_filter + bs_offset, mask=bs_offset < bid, other=0
    )
    new_start = tl.sum(accept_length_filter_all)

    copy_len = tl.load(accept_length_filter + bid)
    copy_offset = tl.arange(0, num_verify_tokens_upper)
    value = tl.load(
        tgt_cache_loc + old_start + copy_offset, mask=copy_offset < copy_len
    )
    tl.store(
        out_cache_loc + new_start + copy_offset, value, mask=copy_offset < copy_len
    )
def create_accept_length_filter(
    accept_length: torch.Tensor,
    unfinished_index_device: torch.Tensor,
    seq_lens: torch.Tensor,
):
    """Return per-request kept lengths, zeroed for finished requests.

    The result has the shape and dtype of ``accept_length``; entries selected
    by ``unfinished_index_device`` hold ``accept_length + 1`` and all other
    entries are 0.

    Side effect: ``seq_lens`` is advanced in place by ``accept_length + 1``
    for every request, finished or not.
    """
    # Accepted draft tokens plus one extra token per request.
    tokens_added = accept_length + 1

    accept_length_filter = torch.zeros_like(accept_length)
    accept_length_filter[unfinished_index_device] = tokens_added[
        unfinished_index_device
    ]

    seq_lens.add_(tokens_added)
    return accept_length_filter
def select_top_k_tokens(
    i: int,
    topk_p: torch.Tensor,
    topk_index: torch.Tensor,
    hidden_states: torch.Tensor,
    scores: torch.Tensor,
    topk: int,
):
    """Pick the ``topk`` tokens to feed into draft step ``i``.

    Args:
        i: Draft step index; ``0`` is the first step after extend.
        topk_p: Top-k probabilities of the previous step.
        topk_index: Token ids matching ``topk_p``.
        hidden_states: Hidden states of the previous step's inputs.
        scores: Cumulative branch scores from the previous step (unused when
            ``i == 0``).
        topk: Number of branches kept per request.

    Returns:
        ``(input_ids, hidden_states, scores, tree_info)`` where ``tree_info``
        is a 3-tuple of (scores, token ids, parent indices) for this step.

    Helper index tensors are created on the device of the input tensors
    rather than on a hard-coded ``"cuda"`` device.
    """
    if i == 0:
        # The first step after extend
        input_ids = topk_index.flatten()
        hidden_states = hidden_states.repeat_interleave(topk, dim=0)
        scores = topk_p  # shape: (b, topk)

        tree_info = (
            topk_p.unsqueeze(1),  # shape: (b, 1, topk)
            topk_index,  # shape: (b, topk)
            torch.arange(-1, topk, dtype=torch.long, device=topk_p.device)
            .unsqueeze(0)
            .repeat(topk_p.shape[0], 1),  # shape: (b, topk + 1)
        )
    else:
        # The later decode steps
        expand_scores = torch.mul(
            scores.unsqueeze(2), topk_p.reshape(-1, topk, topk)
        )  # (b, topk, 1) x (b, topk ,topk) -> (b, topk, topk)
        topk_cs_p, topk_cs_index = fast_topk(
            expand_scores.flatten(start_dim=1), topk, dim=-1
        )  # (b, topk)
        scores = topk_cs_p  # shape: (b, topk)

        topk_index = topk_index.reshape(-1, topk**2)
        input_ids = torch.gather(topk_index, index=topk_cs_index, dim=1).flatten()

        if hidden_states.shape[0] > 0:
            # topk_cs_index // topk is the parent branch inside a request; the
            # arange adds the row offset of that request in hidden_states.
            selected_input_index = topk_cs_index.flatten() // topk + torch.arange(
                0, hidden_states.shape[0], step=topk, device=topk_cs_index.device
            ).repeat_interleave(topk)
            hidden_states = hidden_states[selected_input_index, :]

        tree_info = (
            expand_scores,  # shape: (b, topk, topk)
            topk_index,  # shape: (b, topk * topk)
            topk_cs_index + (topk**2 * (i - 1) + topk),  # shape: (b, topk)
        )

    return input_ids, hidden_states, scores, tree_info
def generate_simulated_accept_index(
    accept_index,
    predict,
    accept_length,
    bs,
    spec_steps,
    simulate_acc_len: float = SIMULATE_ACC_LEN,
    simulate_acc_method: str = SIMULATE_ACC_METHOD,
):
    """Fabricate an accept index with a simulated acceptance length.

    Used for benchmarking: instead of the real verification result, every
    request in the batch is treated as having accepted the same sampled
    number of tokens.

    Args:
        accept_index: Real accept index; only column 0 is used as the base.
        predict: Filled in place with a placeholder token id.
        accept_length: Filled in place with the sampled length minus one.
        bs: Batch size (number of rows of the result).
        spec_steps: Number of speculative steps; the sampled length is kept
            within ``[1, spec_steps + 1]``.
        simulate_acc_len: Target (expected) acceptance length, must be > 0.
        simulate_acc_method: ``"multinomial"`` or ``"match-expected"``.

    Returns:
        An int32 tensor of shape ``(bs, spec_steps + 1)`` on the device of
        ``accept_index``; the first sampled-length columns hold consecutive
        indices starting at ``accept_index[:, 0]`` and the rest are ``-1``.

    Raises:
        ValueError: If ``simulate_acc_method`` is not a known method.
    """
    assert simulate_acc_len > 0.0

    if simulate_acc_method == "multinomial":
        # Draws from a normal distribution around the target length. The mean
        # of the rounded, clamped sample does not match the target exactly;
        # the method is kept for compatibility with existing tests.
        simulated_values = torch.normal(
            mean=simulate_acc_len,
            std=1.0,
            size=(1,),
            device="cpu",
        )
        # clamp simulated values to be between 1 and spec_steps + 1
        simulated_values = torch.clamp(simulated_values, min=1.0, max=spec_steps + 1)
        simulate_acc_len = int(simulated_values.round().item())
    elif simulate_acc_method == "match-expected":
        # Sample either the floor or the ceiling of the target length, with
        # weights chosen so that the expectation equals the target. One caveat
        # is that only those two values can ever be produced.
        simulate_acc_len = max(1.0, min(spec_steps + 1, simulate_acc_len))
        lower = int(simulate_acc_len // 1)
        upper = lower + 1 if lower < spec_steps + 1 else lower
        if lower == upper:
            simulate_acc_len = lower
        else:
            weight_upper = simulate_acc_len - lower
            weight_lower = 1.0 - weight_upper
            probs = torch.tensor([weight_lower, weight_upper], device="cpu")
            sampled_index = torch.multinomial(probs, num_samples=1)
            simulate_acc_len = lower if sampled_index.item() == 0 else upper
    else:
        # Report the value that was actually passed in, not the module default.
        raise ValueError(f"Invalid simulate_acc_method: {simulate_acc_method}")

    # Build everything on the device of the real accept index.
    device = accept_index.device
    accept_indx_first_col = accept_index[:, 0].view(-1, 1)
    sim_accept_index = torch.full(
        (bs, spec_steps + 1), -1, dtype=torch.int32, device=device
    )
    sim_accept_index[:, :simulate_acc_len] = accept_indx_first_col + torch.arange(
        simulate_acc_len, device=device
    )
    accept_length.fill_(simulate_acc_len - 1)
    predict.fill_(100)  # some legit token id
    return sim_accept_index
def traverse_tree(
    retrieve_next_token: torch.Tensor,
    retrieve_next_sibling: torch.Tensor,
    draft_tokens: torch.Tensor,
    grammar: BaseGrammarObject,
    allocate_token_bitmask: torch.Tensor,
):
    """
    Walk the draft tree depth-first and fill the per-node logits mask.

    The tree is given in first-child / next-sibling form: for node ``n``,
    ``retrieve_next_token[n]`` is its first child and
    ``retrieve_next_sibling[n]`` its next sibling (``-1`` means none).
    ``allocate_token_bitmask`` is zeroed first; row ``n`` is then filled by
    the grammar for every node whose token the grammar accepts. The grammar
    is rolled back after each accepted non-root node, so it ends in the state
    it started in.
    """
    assert (
        retrieve_next_token.shape == retrieve_next_sibling.shape == draft_tokens.shape
    )

    allocate_token_bitmask.fill_(0)

    def _allowed_by_parent(node, parent):
        # Node 0 is the token produced by the target model, so it is always
        # accepted from the previous iteration.
        if node == 0:
            return True
        token_id = draft_tokens[node]
        # 32 boolean bitmask values are packed into each 32-bit integer.
        packed_word = allocate_token_bitmask[parent][token_id // 32]
        return (packed_word & (1 << (token_id % 32))) != 0

    def _visit(node, parent):
        if _allowed_by_parent(node, parent):
            if node != 0:
                # Advance the grammar by this draft token.
                grammar.accept_token(draft_tokens[node])

            if not grammar.is_terminated():
                # Record which tokens may follow this node.
                grammar.fill_vocab_mask(allocate_token_bitmask, node)
                first_child = retrieve_next_token[node]
                if first_child != -1:
                    _visit(first_child, node)

            if node != 0:
                # Undo the accept above before moving to the siblings.
                grammar.rollback(1)

        # Siblings share this node's parent and are visited regardless of
        # whether this node was accepted.
        next_sibling = retrieve_next_sibling[node]
        if next_sibling != -1:
            _visit(next_sibling, parent)

    _visit(0, -1)
def generate_token_bitmask(
    reqs: List[Req],
    verify_input: EagleVerifyInput,
    retrieve_next_token_cpu: torch.Tensor,
    retrieve_next_sibling_cpu: torch.Tensor,
    draft_tokens_cpu: torch.Tensor,
    vocab_size: int,
):
    """
    Build the logit mask used for structured (grammar-constrained) output.

    A draft token may or may not be valid under a request's grammar, so for
    every request that has a grammar the draft tree is traversed to
    1. find out which draft tokens the grammar accepts, and
    2. compute the logit mask that applies after each accepted token.

    The mask is allocated lazily on CPU, with one row per draft token of the
    whole batch, by the first grammar encountered. Returns ``None`` when no
    request has a grammar. As a side effect ``verify_input.grammar`` is set
    to the grammar of the last such request (or ``None``).
    """
    num_draft_tokens = draft_tokens_cpu.shape[-1]
    assert len(reqs) == retrieve_next_token_cpu.shape[0]

    allocate_token_bitmask = None
    grammar = None

    for i, req in enumerate(reqs):
        if req.grammar is None:
            continue

        if allocate_token_bitmask is None:
            allocate_token_bitmask = req.grammar.allocate_vocab_mask(
                vocab_size=vocab_size,
                batch_size=draft_tokens_cpu.numel(),
                device="cpu",
            )
        grammar = req.grammar

        # Rows [row_begin, row_end) of the mask belong to request i.
        row_begin = i * num_draft_tokens
        row_end = row_begin + num_draft_tokens

        s = time.perf_counter()
        traverse_tree(
            retrieve_next_token_cpu[i],
            retrieve_next_sibling_cpu[i],
            draft_tokens_cpu[i],
            req.grammar,
            allocate_token_bitmask[row_begin:row_end],
        )
        tree_traverse_time = time.perf_counter() - s

        if tree_traverse_time > TREE_TRAVERSE_TIME_THRESHOLD:
            logger.warning(
                f"Bit mask generation took {tree_traverse_time} seconds with "
                f"grammar: {req.grammar}"
            )

    verify_input.grammar = grammar
    return allocate_token_bitmask
def load_token_map(token_map_path: str) -> torch.Tensor:
    """Load the hot-token id map as an int64 tensor.

    If ``token_map_path`` does not exist locally, its directory part is
    passed to ``snapshot_download`` as the repository to fetch (weight files
    matching ``*.bin`` / ``*.safetensors`` are skipped) and the file with the
    same base name is loaded from the downloaded snapshot.

    Args:
        token_map_path: Local file path, or ``<repo id>/<file name>``.

    Returns:
        The stored token ids as a ``torch.int64`` tensor.
    """
    if not os.path.exists(token_map_path):
        cache_dir = snapshot_download(
            os.path.dirname(token_map_path),
            ignore_patterns=["*.bin", "*.safetensors"],
        )
        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))

    hot_token_id = torch.load(token_map_path, weights_only=True)
    # as_tensor accepts both a saved list and a saved tensor, and avoids the
    # copy-construct warning torch.tensor() emits when handed a tensor.
    return torch.as_tensor(hot_token_id, dtype=torch.int64)
@contextmanager
def draft_tp_context(tp_group: GroupCoordinator):
    """Context manager that runs its body with ``tp_group`` patched in.

    The body of the ``with`` statement executes inside
    ``patch_tensor_parallel_group(tp_group)``; the patch is left when the
    body exits, also on exceptions.
    """
    # Draft model doesn't use dp and has its own tp group.
    # We disable mscclpp now because it doesn't support 2 comm groups.
    with patch_tensor_parallel_group(tp_group):
        yield
def detect_nan(logits_output: LogitsProcessorOutput):
    """Raise ``ValueError`` if the next-token logits contain any NaN.

    The error is also written to the module logger before raising. Returns
    ``None`` when every logit is a number.
    """
    has_nan = torch.isnan(logits_output.next_token_logits).any()
    if not has_nan:
        return

    logger.error("Detected errors during sampling! NaN in the logits.")
    raise ValueError("Detected errors during sampling! NaN in the logits.")
Xet Storage Details
- Size:
- 21.3 kB
- Xet hash:
- 47657a9b9dc444f1eae60d0e07951b34a7a587fbd7fd9b5aa7a39fe2224c5024
·
Xet efficiently stores files, intelligently splitting them into unique chunks and accelerating uploads and downloads. More info.