from typing import Optional

import torch
import torch_npu

from vllm.v1.sample.metadata import SamplingMetadata
from vllm.v1.sample.ops.topk_topp_sampler import TopKTopPSampler, random_sample
from vllm.v1.sample.sampler import Sampler, _SAMPLING_EPS

from vllm_ascend import envs
def apply_top_k_top_p(
    logits: torch.Tensor,
    k: torch.Tensor,
    p: torch.Tensor,
) -> torch.Tensor:
    # Fast path: when both k and p are given, use the fused NPU kernel.
    if p is not None and k is not None:
        return torch_npu.npu_top_k_top_p(logits, p, k)

    probs = logits.softmax(dim=-1)
    probs_sort, _ = probs.sort(dim=-1, descending=False)

    if k is not None:
        # The cutoff is the k-th largest probability: with an ascending
        # sort, that is the element at index (vocab_size - k).
        top_k_count = probs_sort.size(1) - k.to(torch.long)  # shape: (batch,)
        top_k_count = top_k_count.unsqueeze(dim=1)
        top_k_cutoff = probs_sort.gather(-1, top_k_count)

        # Make sure rows with no top-k requested (k == vocab_size) are no-ops.
        no_top_k_mask = (k == logits.shape[1]).unsqueeze(dim=1)
        top_k_cutoff.masked_fill_(no_top_k_mask, -float("inf"))

        elements_to_discard = probs < top_k_cutoff
        logits.masked_fill_(elements_to_discard, -float("inf"))

    if p is not None:
        # Mask out the low-probability tail whose cumulative mass is <= 1 - p,
        # always keeping at least the highest-probability token.
        cumprob = torch.cumsum(probs_sort, dim=-1)
        top_p_mask = cumprob <= 1 - p.unsqueeze(dim=1)
        top_p_mask[:, -1] = False  # at least one token is kept

        top_p_count = top_p_mask.sum(dim=-1).unsqueeze(1)
        top_p_cutoff = probs_sort.gather(-1, top_p_count)
        elements_to_discard = probs < top_p_cutoff
        logits.masked_fill_(elements_to_discard, -float("inf"))

    return logits
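# A minimal usage sketch of apply_top_k_top_p (hypothetical shapes and
# values; illustration only, not executed at import time):
#
#     logits = torch.randn(2, 32000)                  # (batch, vocab)
#     k = torch.tensor([50, 32000])                   # row 1: top-k disabled
#     p = None                                        # top-k-only path
#     filtered = apply_top_k_top_p(logits, k, p)
#
# With p=None the fused torch_npu.npu_top_k_top_p kernel is skipped and the
# sort-based fallback above runs instead.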


# Drop-in replacement for TopKTopPSampler.forward_native (installed by the
# patch at the bottom of this file).
def topk_topp_forward_native(
    self,
    logits: torch.Tensor,
    generators: dict[int, torch.Generator],
    k: Optional[torch.Tensor],
    p: Optional[torch.Tensor],
) -> torch.Tensor:
    """
    PyTorch-native implementation of top-k and top-p sampling.

    The logits tensor may be updated in-place.
    """
    logits = apply_top_k_top_p(logits, k, p)
    probs = logits.softmax(dim=-1, dtype=torch.float32)
    return random_sample(probs, generators)


def apply_top_n_sigma(
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
):
    if sampling_metadata.no_top_n_sigma:
        return logits

    top_n_sigma = sampling_metadata.top_n_sigma[:, None]
    # A value of -1 means top-n-sigma is disabled for that request.
    top_n_sigma_mask = (top_n_sigma != -1)
    # Most-negative float32 value, used as the filter value instead of -inf.
    filter_value = -3.4028e+38
    max_vals, _ = logits.max(dim=-1, keepdim=True)
    std_vals = logits.std(dim=-1, keepdim=True)
    # Keep only logits within top_n_sigma standard deviations of the row max.
    threshold = max_vals - top_n_sigma * std_vals
    # Disabled rows get a threshold that finite logits never fall below,
    # so they pass through unchanged.
    threshold[~top_n_sigma_mask] = filter_value
    mask = (logits < threshold)
    logits = torch.where(mask, filter_value, logits)
    return logits
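# Worked example of the top-n-sigma threshold (illustrative numbers): for a
# row with max logit 10.0, std 2.0, and top_n_sigma = 3, the cutoff is
# 10.0 - 3 * 2.0 = 4.0, so every logit below 4.0 is replaced by filter_value
# and effectively removed from the sampling distribution.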


def sample(
    self,
    logits: torch.Tensor,
    sampling_metadata: SamplingMetadata,
) -> torch.Tensor:
    """Sample logits based on sampling metadata.

    The various logits processing functions called in this method
    may update the logits tensor in-place.
    """
    assert not (sampling_metadata.all_greedy
                and sampling_metadata.all_random)
    if sampling_metadata.all_random:
        greedy_sampled = None
    else:
        greedy_sampled = self.greedy_sample(logits)
        if sampling_metadata.all_greedy:
            return greedy_sampled

    assert sampling_metadata.temperature is not None

    # Apply temperature.
    logits = self.apply_temperature(logits, sampling_metadata.temperature)

    # Apply logits processors that only apply to random sampling
    # (argmax invariant).
    for processor in sampling_metadata.logitsprocs.argmax_invariant:
        logits = processor.apply(logits)

    # Apply top-n-sigma filtering.
    logits = apply_top_n_sigma(logits, sampling_metadata)

    # Apply top-k and/or top-p, then sample.
    random_sampled = self.topk_topp_sampler(
        logits,
        sampling_metadata.generators,
        sampling_metadata.top_k,
        sampling_metadata.top_p,
    )

    if greedy_sampled is None:
        return random_sampled

    # Per-request merge: rows with near-zero temperature take the greedy
    # token, the rest keep their random sample. `out=` reuses the storage
    # of greedy_sampled to avoid an extra allocation.
    sampled = torch.where(
        sampling_metadata.temperature < _SAMPLING_EPS,
        greedy_sampled,
        random_sampled,
        out=greedy_sampled,
    )
    return sampled
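# Merge example (illustrative values): with temperature = tensor([0.0, 0.8]),
# greedy_sampled = tensor([11, 22]) and random_sampled = tensor([33, 44]),
# the torch.where above yields tensor([11, 44]): request 0 is greedy,
# request 1 keeps its random sample.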


# Monkey-patch the upstream vLLM samplers when the corresponding
# vllm-ascend feature flags are enabled.
if envs.VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION:
    TopKTopPSampler.forward_native = topk_topp_forward_native

if envs.VLLM_ASCEND_ENABLE_TOP_N_SIGMA:
    Sampler.sample = sample
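# Both flags are read from vllm_ascend.envs; assuming the usual vllm-style
# envs module, each maps to an environment variable of the same name, e.g.:
#
#     export VLLM_ASCEND_ENABLE_TOPK_TOPP_OPTIMIZATION=1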